Repository: labstack/echo
Branch: master
Commit: 675712da34c1
Files: 127
Total size: 1.1 MB
Directory structure:
gitextract_q3zebxle/
├── .editorconfig
├── .gitattributes
├── .github/
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE.md
│ ├── stale.yml
│ └── workflows/
│ ├── checks.yml
│ └── echo.yml
├── .gitignore
├── API_CHANGES_V5.md
├── CHANGELOG.md
├── CLAUDE.md
├── LICENSE
├── Makefile
├── README.md
├── SECURITY.md
├── _fixture/
│ ├── _fixture/
│ │ └── README.md
│ ├── certs/
│ │ ├── README.md
│ │ ├── cert.pem
│ │ └── key.pem
│ ├── dist/
│ │ ├── private.txt
│ │ └── public/
│ │ ├── assets/
│ │ │ ├── readme.md
│ │ │ └── subfolder/
│ │ │ └── subfolder.md
│ │ ├── index.html
│ │ └── test.txt
│ ├── folder/
│ │ └── index.html
│ └── index.html
├── bind.go
├── bind_test.go
├── binder.go
├── binder_external_test.go
├── binder_generic.go
├── binder_generic_test.go
├── binder_test.go
├── codecov.yml
├── context.go
├── context_generic.go
├── context_generic_test.go
├── context_test.go
├── echo.go
├── echo_test.go
├── echotest/
│ ├── context.go
│ ├── context_external_test.go
│ ├── context_test.go
│ ├── reader.go
│ ├── reader_external_test.go
│ ├── reader_test.go
│ └── testdata/
│ └── test.json
├── go.mod
├── go.sum
├── group.go
├── group_test.go
├── httperror.go
├── httperror_external_test.go
├── httperror_test.go
├── ip.go
├── ip_test.go
├── json.go
├── json_test.go
├── middleware/
│ ├── DEVELOPMENT.md
│ ├── basic_auth.go
│ ├── basic_auth_test.go
│ ├── body_dump.go
│ ├── body_dump_test.go
│ ├── body_limit.go
│ ├── body_limit_test.go
│ ├── compress.go
│ ├── compress_test.go
│ ├── context_timeout.go
│ ├── context_timeout_test.go
│ ├── cors.go
│ ├── cors_test.go
│ ├── csrf.go
│ ├── csrf_test.go
│ ├── decompress.go
│ ├── decompress_test.go
│ ├── extractor.go
│ ├── extractor_test.go
│ ├── key_auth.go
│ ├── key_auth_test.go
│ ├── method_override.go
│ ├── method_override_test.go
│ ├── middleware.go
│ ├── middleware_test.go
│ ├── proxy.go
│ ├── proxy_test.go
│ ├── rate_limiter.go
│ ├── rate_limiter_test.go
│ ├── recover.go
│ ├── recover_test.go
│ ├── redirect.go
│ ├── redirect_test.go
│ ├── request_id.go
│ ├── request_id_test.go
│ ├── request_logger.go
│ ├── request_logger_test.go
│ ├── rewrite.go
│ ├── rewrite_test.go
│ ├── secure.go
│ ├── secure_test.go
│ ├── slash.go
│ ├── slash_test.go
│ ├── static.go
│ ├── static_other.go
│ ├── static_test.go
│ ├── testdata/
│ │ ├── dist/
│ │ │ ├── private.txt
│ │ │ └── public/
│ │ │ ├── assets/
│ │ │ │ ├── readme.md
│ │ │ │ └── subfolder/
│ │ │ │ └── subfolder.md
│ │ │ ├── index.html
│ │ │ └── test.txt
│ │ └── private.txt
│ ├── util.go
│ └── util_test.go
├── renderer.go
├── renderer_test.go
├── response.go
├── response_test.go
├── route.go
├── route_test.go
├── router.go
├── router_concurrent.go
├── router_concurrent_test.go
├── router_test.go
├── server.go
├── server_test.go
├── version.go
├── vhost.go
└── vhost_test.go
================================================
FILE CONTENTS
================================================
================================================
FILE: .editorconfig
================================================
# EditorConfig coding styles definitions. For more information about the
# properties used in this file, please see the EditorConfig documentation:
# http://editorconfig.org/
# indicate this is the root of the project
root = true
[*]
charset = utf-8
end_of_line = LF
insert_final_newline = true
trim_trailing_whitespace = true
indent_style = space
indent_size = 2
[Makefile]
indent_style = tab
[*.md]
trim_trailing_whitespace = false
[*.go]
indent_style = tab
================================================
FILE: .gitattributes
================================================
# Automatically normalize line endings for all text-based files
# http://git-scm.com/docs/gitattributes#_end_of_line_conversion
* text=auto
# For the following file types, normalize line endings to LF on check-in and
# prevent conversion to CRLF when they are checked out (this is required in
# order to prevent newline related issues)
.* text eol=lf
*.go text eol=lf
*.yml text eol=lf
*.html text eol=lf
*.css text eol=lf
*.js text eol=lf
*.json text eol=lf
LICENSE text eol=lf
================================================
FILE: .github/FUNDING.yml
================================================
# These are supported funding model platforms
github: [labstack]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
================================================
FILE: .github/ISSUE_TEMPLATE.md
================================================
### Issue Description
### Working code to debug
```go
package main
import (
"github.com/labstack/echo/v5"
"net/http"
"net/http/httptest"
"testing"
)
func TestExample(t *testing.T) {
e := echo.New()
e.GET("/", func(c *echo.Context) error {
return c.String(http.StatusOK, "Hello, World!")
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("got %d, want %d", rec.Code, http.StatusOK)
}
}
```
### Version/commit
================================================
FILE: .github/stale.yml
================================================
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 30
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
- bug
- enhancement
# Label to use when marking an issue as stale
staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed within a month if no further activity occurs.
Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: false
================================================
FILE: .github/workflows/checks.yml
================================================
name: Run checks
on:
push:
branches:
- master
pull_request:
branches:
- master
workflow_dispatch:
permissions:
contents: read # to fetch code (actions/checkout)
env:
# run static analysis only with the latest Go version
LATEST_GO_VERSION: "1.26"
jobs:
check:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v5
- name: Set up Go ${{ env.LATEST_GO_VERSION }}
uses: actions/setup-go@v5
with:
go-version: ${{ env.LATEST_GO_VERSION }}
check-latest: true
- name: Run golint
run: |
go install golang.org/x/lint/golint@latest
golint -set_exit_status ./...
- name: Run staticcheck
run: |
go install honnef.co/go/tools/cmd/staticcheck@latest
staticcheck ./...
- name: Run govulncheck
run: |
go version
go install golang.org/x/vuln/cmd/govulncheck@latest
govulncheck ./...
================================================
FILE: .github/workflows/echo.yml
================================================
name: Run Tests
on:
push:
branches:
- master
pull_request:
branches:
- master
workflow_dispatch:
permissions:
contents: read # to fetch code (actions/checkout)
env:
# run coverage and benchmarks only with the latest Go version
LATEST_GO_VERSION: "1.26"
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
# Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
# Echo tests with the last four major releases (unless there are pressing vulnerabilities).
# As we depend on `golang.org/x/` libraries, which only support the last 2 Go releases, there may be situations where
# we deviate from the "last four major releases" promise.
go: ["1.25", "1.26"]
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Code
uses: actions/checkout@v5
- name: Set up Go ${{ matrix.go }}
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- name: Run Tests
run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
- name: Upload coverage to Codecov
if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v5
with:
token:
fail_ci_if_error: false
benchmark:
needs: test
name: Benchmark comparison
runs-on: ubuntu-latest
steps:
- name: Checkout Code (Previous)
uses: actions/checkout@v5
with:
ref: ${{ github.base_ref }}
path: previous
- name: Checkout Code (New)
uses: actions/checkout@v5
with:
path: new
- name: Set up Go ${{ env.LATEST_GO_VERSION }}
uses: actions/setup-go@v5
with:
go-version: ${{ env.LATEST_GO_VERSION }}
- name: Install Dependencies
run: go install golang.org/x/perf/cmd/benchstat@latest
- name: Run Benchmark (Previous)
run: |
cd previous
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchmark (New)
run: |
cd new
go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
- name: Run Benchstat
run: |
benchstat previous/benchmark.txt new/benchmark.txt
================================================
FILE: .gitignore
================================================
.DS_Store
coverage.txt
_test
vendor
.idea
*.iml
*.out
.vscode
================================================
FILE: API_CHANGES_V5.md
================================================
# Echo v5 Public API Changes
**Comparison between `master` (v4.15.0) and `v5` (v5.0.0-alpha) branches**
Generated: 2026-01-01
---
## Executive Summary (by authors)
Echo `v5` is a maintenance release with **major breaking changes**:
- `Context` is now a struct instead of an interface, so new methods can be added to it in future minor versions.
- Adds a new `Router` interface for possible new routing implementations.
- Drops the old logging interface and uses the modern `log/slog` instead.
- Rearranges a lot of method/function signatures to make them more consistent.
## Executive Summary (by LLMs)
Echo v5 represents a **major breaking release** with significant architectural changes focused on:
- **Updated generic helpers** to take `*Context` and rename form helpers to `FormValue*`
- **Simplified API surface** by moving Context from interface to concrete struct
- **Modern Go patterns** including slog.Logger integration
- **Enhanced routing** with explicit RouteInfo and Routes types
- **Better error handling** with simplified HTTPError
- **New test helpers** via the `echotest` package
### Change Statistics
- **Major Breaking Changes**: 15+
- **New Functions Added**: 30+
- **Type Signature Changes**: 20+
- **Removed APIs**: 10+
- **New Packages Added**: 1 (`echotest`)
- **Version Change**: `4.15.0` → `5.0.0-alpha`
---
## Critical Breaking Changes
### 1. **Context: Interface → Concrete Struct**
**v4 (master):**
```go
type Context interface {
Request() *http.Request
// ... many methods
}
// Handler signature
func handler(c echo.Context) error
```
**v5:**
```go
type Context struct {
// Has unexported fields
}
// Handler signature - NOW USES POINTER!
func handler(c *echo.Context) error
```
**Impact:** 🔴 **CRITICAL BREAKING CHANGE**
- ALL handlers must change from `echo.Context` to `*echo.Context`
- Context is now a concrete struct, not an interface
- This affects every single handler function in user code
**Migration:**
```go
// Before (v4)
func MyHandler(c echo.Context) error {
return c.JSON(200, map[string]string{"hello": "world"})
}
// After (v5)
func MyHandler(c *echo.Context) error {
return c.JSON(200, map[string]string{"hello": "world"})
}
```
---
### 2. **Logger: Custom Interface → slog.Logger**
**v4:**
```go
type Echo struct {
Logger Logger // Custom interface with Print, Debug, Info, etc.
}
type Logger interface {
Output() io.Writer
SetOutput(w io.Writer)
Prefix() string
// ... many custom methods
}
// Context returns Logger interface
func (c Context) Logger() Logger
```
**v5:**
```go
type Echo struct {
Logger *slog.Logger // Standard library structured logger
}
// Context returns slog.Logger
func (c *Context) Logger() *slog.Logger
func (c *Context) SetLogger(logger *slog.Logger)
```
**Impact:** 🔴 **BREAKING CHANGE**
- Must use Go's standard `log/slog` package
- Logger interface completely removed
- All logging code needs updating
---
### 3. **Router: From Router to DefaultRouter**
**v4:**
```go
type Router struct { ... }
func NewRouter(e *Echo) *Router
func (e *Echo) Router() *Router
```
**v5:**
```go
type DefaultRouter struct { ... }
func NewRouter(config RouterConfig) *DefaultRouter
func (e *Echo) Router() Router // Returns interface
```
**Changes:**
- New `Router` interface introduced
- `DefaultRouter` is the concrete implementation
- `NewRouter()` now takes `RouterConfig` instead of `*Echo`
- Added `NewConcurrentRouter(r Router) Router` for thread-safe routing
---
### 4. **Route Return Types Changed**
**v4:**
```go
func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route
func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) []*Route
func (e *Echo) Routes() []*Route
```
**v5:**
```go
func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
func (e *Echo) Match(...) Routes // Returns Routes type
func (e *Echo) Router() Router // Returns interface
```
**New Types:**
```go
type RouteInfo struct {
Name string
Method string
Path string
Parameters []string
}
type Routes []RouteInfo // Collection with helper methods
```
**Impact:** 🔴 **BREAKING CHANGE**
- Route registration methods return `RouteInfo` instead of `*Route`
- New `Routes` collection type with filtering methods
- `Route` struct still exists but is used differently
---
### 5. **Response Type Changed**
**v4:**
```go
func (c Context) Response() *Response
type Response struct {
Writer http.ResponseWriter
Status int
Size int64
Committed bool
}
func NewResponse(w http.ResponseWriter, e *Echo) *Response
```
**v5:**
```go
func (c *Context) Response() http.ResponseWriter
type Response struct {
http.ResponseWriter // Embedded
Status int
Size int64
Committed bool
}
func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response
func UnwrapResponse(rw http.ResponseWriter) (*Response, error)
```
**Changes:**
- Context.Response() returns `http.ResponseWriter` instead of `*Response`
- Response now embeds `http.ResponseWriter`
- NewResponse takes `*slog.Logger` instead of `*Echo`
- New `UnwrapResponse()` helper function
---
### 6. **HTTPError Simplified**
**v4:**
```go
type HTTPError struct {
Internal error
Message interface{} // Can be any type
Code int
}
func NewHTTPError(code int, message ...interface{}) *HTTPError
```
**v5:**
```go
type HTTPError struct {
Code int
Message string // Now string only
// Has unexported fields (Internal moved)
}
func NewHTTPError(code int, message string) *HTTPError
func (he HTTPError) Wrap(err error) error // New method
func (he *HTTPError) StatusCode() int // Implements HTTPStatusCoder
```
**Changes:**
- `Message` field changed from `interface{}` to `string`
- `NewHTTPError()` now takes `string` instead of `...interface{}`
- Added `HTTPStatusCoder` interface and `StatusCode()` method
- Added `Wrap(err error)` method for error wrapping
---
### 7. **HTTPErrorHandler Signature Changed**
**v4:**
```go
type HTTPErrorHandler func(err error, c Context)
func (e *Echo) DefaultHTTPErrorHandler(err error, c Context)
```
**v5:**
```go
type HTTPErrorHandler func(c *Context, err error) // Parameters swapped!
func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler // Now a factory
```
**Impact:** 🔴 **BREAKING CHANGE**
- Parameter order reversed: `(c *Context, err error)` instead of `(err error, c Context)`
- DefaultHTTPErrorHandler is now a factory function that returns HTTPErrorHandler
- Takes `exposeError` bool to control error message exposure
---
## Notable API Changes in v5
### 1. **Generic Parameter Extraction Functions (Updated Signatures)**
These helpers keep the same generic API but now accept `*Context`, and the
form helpers are renamed from `FormParam*` to `FormValue*`:
```go
// Query Parameters
func QueryParam[T any](c *Context, key string, opts ...any) (T, error)
func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error)
func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
// Path Parameters
func PathParam[T any](c *Context, paramName string, opts ...any) (T, error)
func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error)
// Form Values
func FormValue[T any](c *Context, key string, opts ...any) (T, error)
func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
func FormValues[T any](c *Context, key string, opts ...any) ([]T, error)
func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
// Generic Parsing
func ParseValue[T any](value string, opts ...any) (T, error)
func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error)
func ParseValues[T any](values []string, opts ...any) ([]T, error)
func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error)
```
`FormParam*` was renamed to `FormValue*`; the rest keep names but now take `*Context`.
**Supported Types:**
- bool, string
- int, int8, int16, int32, int64
- uint, uint8, uint16, uint32, uint64
- float32, float64
- time.Time, time.Duration
- BindUnmarshaler, encoding.TextUnmarshaler, json.Unmarshaler
**Example Usage:**
```go
// v5 - Type-safe parameter binding
id, err := echo.PathParam[int](c, "id")
page, err := echo.QueryParamOr[int](c, "page", 1)
tags, err := echo.QueryParams[string](c, "tags")
```
---
### 2. **Context Store Helpers Now Use `*Context`**
```go
// Type-safe context value retrieval
func ContextGet[T any](c *Context, key string) (T, error)
func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error)
// Error types
var ErrNonExistentKey = errors.New("non existent key")
var ErrInvalidKeyType = errors.New("invalid key type")
```
These helpers existed in v4 with `Context` and now accept `*Context`.
**Example:**
```go
// v5
user, err := echo.ContextGet[*User](c, "user")
count, err := echo.ContextGetOr[int](c, "count", 0)
```
---
### 3. **PathValues Type**
New structured path parameter handling:
```go
type PathValue struct {
Name string
Value string
}
type PathValues []PathValue
func (p PathValues) Get(name string) (string, bool)
func (p PathValues) GetOr(name string, defaultValue string) string
// Context methods
func (c *Context) PathValues() PathValues
func (c *Context) SetPathValues(pathValues PathValues)
```
---
### 4. **Time Parsing Options**
```go
type TimeLayout string
const (
TimeLayoutUnixTime = TimeLayout("UnixTime")
TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli")
TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano")
)
type TimeOpts struct {
Layout TimeLayout
ParseInLocation *time.Location
ToInLocation *time.Location
}
```
---
### 5. **StartConfig for Server Configuration**
```go
type StartConfig struct {
Address string
HideBanner bool
HidePort bool
CertFilesystem fs.FS
TLSConfig *tls.Config
ListenerNetwork string
ListenerAddrFunc func(addr net.Addr)
GracefulTimeout time.Duration
OnShutdownError func(err error)
BeforeServeFunc func(s *http.Server) error
}
func (sc StartConfig) Start(ctx context.Context, h http.Handler) error
func (sc StartConfig) StartTLS(ctx context.Context, h http.Handler, certFile, keyFile any) error
```
**Example:**
```go
// v5 - More control over server startup
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel()
sc := echo.StartConfig{
Address: ":8080",
GracefulTimeout: 10 * time.Second,
}
if err := sc.Start(ctx, e); err != nil {
log.Fatal(err)
}
```
---
### 6. **Echo Config and Constructors**
```go
type Config struct {
// Configuration for Echo (logger, binder, renderer, etc.)
}
func NewWithConfig(config Config) *Echo
```
This adds a configuration struct for creating an `Echo` instance without
mutating fields after `New()`.
---
### 7. **Enhanced Routing Features**
```go
// New route methods
func (e *Echo) AddRoute(route Route) (RouteInfo, error)
func (e *Echo) Middlewares() []MiddlewareFunc
func (e *Echo) PreMiddlewares() []MiddlewareFunc
type AddRouteError struct{ ... }
// Routes collection with filters
type Routes []RouteInfo
func (r Routes) Clone() Routes
func (r Routes) FilterByMethod(method string) (Routes, error)
func (r Routes) FilterByName(name string) (Routes, error)
func (r Routes) FilterByPath(path string) (Routes, error)
func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error)
func (r Routes) Reverse(routeName string, pathValues ...any) (string, error)
// RouteInfo operations
func (r RouteInfo) Clone() RouteInfo
func (r RouteInfo) Reverse(pathValues ...any) string
```
---
### 8. **Middleware Configuration Interface**
```go
type MiddlewareConfigurator interface {
ToMiddleware() (MiddlewareFunc, error)
}
```
Allows middleware configs to be converted to middleware without panicking.
---
### 9. **New Context Methods**
```go
// v5 additions
func (c *Context) FileFS(file string, filesystem fs.FS) error
func (c *Context) FormValueOr(name, defaultValue string) string
func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues)
func (c *Context) ParamOr(name, defaultValue string) string
func (c *Context) QueryParamOr(name, defaultValue string) string
func (c *Context) RouteInfo() RouteInfo
```
---
### 10. **Virtual Host Support**
```go
func NewVirtualHostHandler(vhosts map[string]*Echo) *Echo
```
Creates an Echo instance that routes requests to different Echo instances based on host.
---
### 11. **New Binder Functions**
```go
func BindBody(c *Context, target any) error
func BindHeaders(c *Context, target any) error
func BindPathValues(c *Context, target any) error // Renamed from BindPathParams
func BindQueryParams(c *Context, target any) error
```
Top-level binding functions that work with `*Context`.
---
### 12. **New echotest Package**
```go
package echotest // import "github.com/labstack/echo/v5/echotest"
func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte
func TrimNewlineEnd(bytes []byte) []byte
type ContextConfig struct{ ... }
type MultipartForm struct{ ... }
type MultipartFormFile struct{ ... }
```
Helpers for loading fixtures and constructing test contexts.
---
## Removed APIs in v5
### Constants
```go
// v4 - Removed in v5
const CONNECT = http.MethodConnect // Use http.MethodConnect directly
```
**Reason:** Deprecated in v4, use stdlib `http.Method*` constants instead.
---
### Constants Added in v5
```go
// v5 additions
const (
NotFoundRouteName = "echo_route_not_found_name"
)
```
---
### Error Variable Changes
**v4 exports:**
```go
ErrBadRequest
ErrInvalidKeyType
ErrNonExistentKey
```
**v5 exports:**
```go
ErrBadRequest // Now backed by unexported httpError type
ErrValidatorNotRegistered // New
ErrInvalidKeyType
ErrNonExistentKey
```
**Reason:** v5 centralizes on `NewHTTPError(code, message)` rather than a broad set
of predefined HTTP error variables.
---
### Functions Removed
```go
// v4 - Removed in v5
func GetPath(r *http.Request) string // Use r.URL.Path or r.URL.RawPath
```
### Variables Removed
```go
// v4 - Removed in v5
var MethodNotAllowedHandler = func(c Context) error { ... }
var NotFoundHandler = func(c Context) error { ... }
```
### Functions Renamed
```go
// v4
func FormParam[T any](c Context, key string, opts ...any) (T, error)
func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error)
func FormParams[T any](c Context, key string, opts ...any) ([]T, error)
func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error)
// v5
func FormValue[T any](c *Context, key string, opts ...any) (T, error)
func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
func FormValues[T any](c *Context, key string, opts ...any) ([]T, error)
func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
```
---
### Type Methods Removed/Changed
**Echo struct changes:**
```go
// v4 fields removed in v5
type Echo struct {
StdLogger *stdLog.Logger // Removed
Server *http.Server // Removed (use StartConfig)
TLSServer *http.Server // Removed (use StartConfig)
Listener net.Listener // Removed (use StartConfig)
TLSListener net.Listener // Removed (use StartConfig)
AutoTLSManager autocert.Manager // Removed
ListenerNetwork string // Removed
OnAddRouteHandler func(...) // Changed to OnAddRoute
DisableHTTP2 bool // Removed (use StartConfig)
Debug bool // Removed
HideBanner bool // Removed (use StartConfig)
HidePort bool // Removed (use StartConfig)
}
// v5 Echo struct (simplified)
type Echo struct {
Binder Binder
Filesystem fs.FS // NEW
Renderer Renderer
Validator Validator
JSONSerializer JSONSerializer
IPExtractor IPExtractor
OnAddRoute func(route Route) error // Simplified
HTTPErrorHandler HTTPErrorHandler
Logger *slog.Logger // Changed from Logger interface
}
```
---
**Context interface → struct:**
```go
// v4
type Context interface {
// Had: SetResponse(*Response)
Response() *Response
// Had: ParamNames(), SetParamNames(), ParamValues(), SetParamValues()
// These are removed in v5 (use PathValues() instead)
}
// v5
type Context struct {
// Concrete struct with unexported fields
}
func (c *Context) Response() http.ResponseWriter // Changed return type
func (c *Context) PathValues() PathValues // Replaces ParamNames/Values
```
---
**Types removed:**
```go
// v4
type Map map[string]interface{}
```
**Group changes:**
```go
// v4
func (g *Group) File(path, file string) // No return value
func (g *Group) Static(pathPrefix, fsRoot string) // No return value
func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) // No return value
// v5
func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo
func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo
```
Now return `RouteInfo` and accept middleware.
---
### Value Binder Factory Name Changes
```go
// v4
func PathParamsBinder(c Context) *ValueBinder
func QueryParamsBinder(c Context) *ValueBinder
func FormFieldBinder(c Context) *ValueBinder
// v5
func PathValuesBinder(c *Context) *ValueBinder // Renamed
func QueryParamsBinder(c *Context) *ValueBinder
func FormFieldBinder(c *Context) *ValueBinder
```
---
## Type Signature Changes
### Binder Interface
```go
// v4
type Binder interface {
Bind(i interface{}, c Context) error
}
// v5
type Binder interface {
Bind(c *Context, target any) error // Parameters swapped!
}
```
---
### DefaultBinder Methods
```go
// v4
func (b *DefaultBinder) Bind(i interface{}, c Context) error
func (b *DefaultBinder) BindBody(c Context, i interface{}) error
func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error
// v5
func (b *DefaultBinder) Bind(c *Context, target any) error // Swapped params
// BindBody, BindPathParams, etc. are now top-level functions
```
---
### JSONSerializer Interface
```go
// v4
type JSONSerializer interface {
Serialize(c Context, i interface{}, indent string) error
Deserialize(c Context, i interface{}) error
}
// v5
type JSONSerializer interface {
Serialize(c *Context, target any, indent string) error
Deserialize(c *Context, target any) error
}
```
---
### Renderer Interface
```go
// v4
type Renderer interface {
Render(io.Writer, string, interface{}, Context) error
}
// v5
type Renderer interface {
Render(c *Context, w io.Writer, templateName string, data any) error
}
```
Parameters reordered with Context first.
---
### NewBindingError
```go
// v4
func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error
// v5
func NewBindingError(sourceParam string, values []string, message string, err error) error
```
Message parameter changed from `interface{}` to `string`.
---
### HandlerName
```go
// v5 only
func HandlerName(h HandlerFunc) string
```
New utility function to get handler function name.
---
## Middleware Package Changes
### Signature and Type Updates
```go
// CORS now accepts optional allow-origins
func CORS(allowOrigins ...string) echo.MiddlewareFunc
// BodyLimit now accepts bytes
func BodyLimit(limitBytes int64) echo.MiddlewareFunc
// DefaultSkipper now uses *echo.Context
func DefaultSkipper(c *echo.Context) bool
// Trailing slash configs renamed/split
func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc
func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc
type AddTrailingSlashConfig struct{ ... }
type RemoveTrailingSlashConfig struct{ ... }
// Auth + extractor signatures now use *echo.Context and add ExtractorSource
type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error)
type Extractor func(c *echo.Context) (string, error)
type ExtractorSource string
type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error)
type KeyAuthErrorHandler func(c *echo.Context, err error) error
// BodyDump handler now includes err
type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error)
// ValuesExtractor now returns extractor source and CreateExtractors takes a limit
type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error)
func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error)
type ValueExtractorError struct{ ... }
// New constants
const KB = 1024
// Rate limiter store now takes a float64 limit
func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore)
```
### Added Middleware Exports
```go
var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
var RedirectHTTPSConfig = RedirectConfig{ ... }
var RedirectHTTPSWWWConfig = RedirectConfig{ ... }
var RedirectNonHTTPSWWWConfig = RedirectConfig{ ... }
var RedirectNonWWWConfig = RedirectConfig{ ... }
var RedirectWWWConfig = RedirectConfig{ ... }
```
### Removed/Consolidated Middleware Exports
```go
// Removed in v5
func Logger() echo.MiddlewareFunc
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc
func Timeout() echo.MiddlewareFunc
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc
type ErrKeyAuthMissing struct{ ... }
type CSRFErrorHandler func(err error, c echo.Context) error
type LoggerConfig struct{ ... }
type LogErrorFunc func(c echo.Context, err error, stack []byte) error
type TargetProvider interface{ ... }
type TrailingSlashConfig struct{ ... }
type TimeoutConfig struct{ ... }
```
Also removed defaults: `DefaultBasicAuthConfig`, `DefaultBodyDumpConfig`, `DefaultBodyLimitConfig`,
`DefaultCORSConfig`, `DefaultDecompressConfig`, `DefaultGzipConfig`, `DefaultLoggerConfig`,
`DefaultRedirectConfig`, `DefaultRequestIDConfig`, `DefaultRewriteConfig`, `DefaultTimeoutConfig`,
`DefaultTrailingSlashConfig`.
---
## Router Interface Changes
### v4 Router (Concrete Struct)
```go
type Router struct { ... }
func NewRouter(e *Echo) *Router
func (r *Router) Add(method, path string, h HandlerFunc)
func (r *Router) Find(method, path string, c Context)
func (r *Router) Reverse(name string, params ...interface{}) string
func (r *Router) Routes() []*Route
```
### v5 Router (Interface + DefaultRouter)
```go
type Router interface {
Add(routable Route) (RouteInfo, error)
Remove(method string, path string) error
Routes() Routes
Route(c *Context) HandlerFunc
}
type DefaultRouter struct { ... }
func NewRouter(config RouterConfig) *DefaultRouter
func NewConcurrentRouter(r Router) Router // NEW
type RouterConfig struct {
NotFoundHandler HandlerFunc
MethodNotAllowedHandler HandlerFunc
OptionsMethodHandler HandlerFunc
AllowOverwritingRoute bool
UnescapePathParamValues bool
UseEscapedPathForMatching bool
}
```
**Key Changes:**
- Router is now an interface
- DefaultRouter is the concrete implementation
- Add() returns `(RouteInfo, error)` instead of being void
- New `Remove()` method
- New `Route()` method replaces `Find()`
- Configuration through `RouterConfig`
---
## Echo Instance Method Changes
### Route Registration
```go
// v4
func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route
// v5
func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo
func (e *Echo) AddRoute(route Route) (RouteInfo, error) // NEW
```
### Static File Serving
```go
// v4
func (e *Echo) Static(pathPrefix, fsRoot string) *Route
func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route
func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route
func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route
// v5
func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo
func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo
func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo
```
Return type changed from `*Route` to `RouteInfo`.
### Server Management
```go
// v4
func (e *Echo) Start(address string) error
func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) error
func (e *Echo) StartAutoTLS(address string) error
func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error
func (e *Echo) StartServer(s *http.Server) error
func (e *Echo) Shutdown(ctx context.Context) error
func (e *Echo) Close() error
func (e *Echo) ListenerAddr() net.Addr
func (e *Echo) TLSListenerAddr() net.Addr
func (e *Echo) DefaultHTTPErrorHandler(err error, c Context)
// v5
func (e *Echo) Start(address string) error // Simplified
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request)
// Removed: StartTLS, StartAutoTLS, StartH2CServer, StartServer
// Use StartConfig instead for advanced server configuration
// Removed: Shutdown, Close, ListenerAddr, TLSListenerAddr
// Removed: DefaultHTTPErrorHandler (now a top-level factory function)
```
**v5 provides** `StartConfig` type for all advanced server configuration.
### Router Access
```go
// v4
func (e *Echo) Router() *Router
func (e *Echo) Routers() map[string]*Router // For multi-host
func (e *Echo) Routes() []*Route
func (e *Echo) Reverse(name string, params ...interface{}) string
func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string
func (e *Echo) URL(h HandlerFunc, params ...interface{}) string
func (e *Echo) Host(name string, m ...MiddlewareFunc) *Group
// v5
func (e *Echo) Router() Router // Returns interface
// Removed: Routers(), Reverse(), URI(), URL(), Host()
// Use router.Routes() and Routes.Reverse() instead
```
---
## NewContext Changes
```go
// v4
func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context
func NewResponse(w http.ResponseWriter, e *Echo) *Response
// v5
func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context
func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context // Standalone
func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response
```
---
## Migration Guide Summary
If you are using Linux, you can migrate the easier parts like this:
```bash
find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} +
find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} +
```
Or, in your favorite IDE, replace all:
1. ` echo.Context` -> ` *echo.Context`
2. `echo/v4` -> `echo/v5`
### 1. Update All Handler Signatures
```go
// Before
func MyHandler(c echo.Context) error { ... }
// After
func MyHandler(c *echo.Context) error { ... }
```
### 2. Update Logger Usage
```go
// Before
e.Logger.Info("Server started")
c.Logger().Error("Something went wrong")
// After
e.Logger.Info("Server started")
c.Logger().Error("Something went wrong") // Same API, different logger
```
### 3. Use Type-Safe Parameter Extraction
```go
// Before
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
// After
id, err := echo.PathParam[int](c, "id")
```
### 4. Update Error Handler
```go
// Before
e.HTTPErrorHandler = func(err error, c echo.Context) {
// handle error
}
// After
e.HTTPErrorHandler = func(c *echo.Context, err error) { // Swapped!
// handle error
}
// Or use factory
e.HTTPErrorHandler = echo.DefaultHTTPErrorHandler(true) // exposeError=true
```
### 5. Update Server Startup
```go
// Before
e.Start(":8080")
e.StartTLS(":443", "cert.pem", "key.pem")
// After
// Simple
e.Start(":8080")
// Advanced with graceful shutdown
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
sc := echo.StartConfig{Address: ":8080"}
sc.Start(ctx, e)
```
### 6. Update Route Info Access
```go
// Before
routes := e.Routes()
for _, r := range routes {
fmt.Println(r.Method, r.Path)
}
// After
routes := e.Router().Routes()
for _, r := range routes {
fmt.Println(r.Method, r.Path)
}
```
### 7. Update HTTPError Creation
```go
// Before
return echo.NewHTTPError(400, "invalid request", someDetail)
// After
return echo.NewHTTPError(400, "invalid request")
```
### 8. Update Custom Binder
```go
// Before
type MyBinder struct{}
func (b *MyBinder) Bind(i interface{}, c echo.Context) error { ... }
// After
type MyBinder struct{}
func (b *MyBinder) Bind(c *echo.Context, target any) error { ... } // Swapped!
```
### 9. Path Parameters
```go
// Before
names := c.ParamNames()
values := c.ParamValues()
// After
pathValues := c.PathValues()
for _, pv := range pathValues {
fmt.Println(pv.Name, pv.Value)
}
```
### 10. Response Access
```go
// Before
resp := c.Response()
resp.Header().Set("X-Custom", "value")
// After
c.Response().Header().Set("X-Custom", "value") // Returns http.ResponseWriter
// To get *echo.Response
resp, err := echo.UnwrapResponse(c.Response())
```
### Go Version Requirements
- **v4**: Go 1.24.0 (per `go.mod`)
- **v5**: Go 1.25.0 (per `go.mod`)
---
**Generated by comparing `go doc` output from master (v4.15.0) and v5 (v5.0.0-alpha) branches**
================================================
FILE: CHANGELOG.md
================================================
# Changelog
## v5.0.4 - 2026-02-15
**Enhancements**
* Remove unused import 'errors' from README example by @kumapower17 in https://github.com/labstack/echo/pull/2889
* Fix Graceful shutdown: after `http.Server.Serve` returns we need to wait for graceful shutdown goroutine to finish by @aldas in https://github.com/labstack/echo/pull/2898
* Update location of oapi-codegen in README by @mromaszewicz in https://github.com/labstack/echo/pull/2896
* Add Go 1.26 to CI flow by @aldas in https://github.com/labstack/echo/pull/2899
* Add new function `echo.StatusCode` by @suwakei in https://github.com/labstack/echo/pull/2892
* CSRF: support older token-based CSRF protection handlers that want to render the token into a template by @aldas in https://github.com/labstack/echo/pull/2894
* Add `echo.ResolveResponseStatus` function to help middleware/handlers determine HTTP status code and echo.Response by @aldas in https://github.com/labstack/echo/pull/2900
## v5.0.3 - 2026-02-06
**Security**
* Fix directory traversal vulnerability under Windows in Static middleware when default Echo filesystem is used. Reported by @shblue21.
This applies to cases when:
- Windows is used as OS
- `middleware.StaticConfig.Filesystem` is `nil` (default)
- `echo.Filesystem` has not been set explicitly (default)
Exposure is restricted to the active process working directory and its subfolders.
## v5.0.2 - 2026-02-02
**Security**
* Fix Static middleware with `config.Browse=true` listing all files/subfolders from the `config.Filesystem` root instead of starting from `config.Root` in https://github.com/labstack/echo/pull/2887
## v5.0.1 - 2026-01-28
* Panic MW: will now return a custom PanicStackError with stack trace by @aldas in https://github.com/labstack/echo/pull/2871
* Docs: add missing err parameter to DenyHandler example by @cgalibern in https://github.com/labstack/echo/pull/2878
* improve: improve websocket checks in IsWebSocket() [per RFC 6455] by @raju-mechatronics in https://github.com/labstack/echo/pull/2875
* fix: Context.JSON() should not send status code before serialization is complete by @aldas in https://github.com/labstack/echo/pull/2877
## v5.0.0 - 2026-01-18
Echo `v5` is a maintenance release with **major breaking changes**:
- `Context` is now a struct instead of an interface, so we can add methods to it in future minor versions.
- Adds a new `Router` interface to allow for new routing implementations.
- Drops the old logging interface and uses the modern `log/slog` instead.
- Rearranges a lot of method/function signatures to make them more consistent.
Upgrade notes and `v4` support:
- Echo `v4` is supported with **security** updates and **bug** fixes until **2026-12-31**
- If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading.
- Until 2026-03-31, any critical issues requiring breaking `v5` API changes will be addressed, even if this violates semantic versioning.
See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, and notes on **upgrading**.
Upgrading TLDR:
If you are using Linux, you can migrate the easier parts like this:
```bash
find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} +
find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} +
```
On macOS:
```bash
find . -type f -name "*.go" -exec sed -i '' 's/ echo.Context/ *echo.Context/g' {} +
find . -type f -name "*.go" -exec sed -i '' 's/echo\/v4/echo\/v5/g' {} +
```
Or, in your favorite IDE, replace all:
1. ` echo.Context` -> ` *echo.Context`
2. `echo/v4` -> `echo/v5`
This should solve most of the issues. Probably the hardest part is updating all the tests.
## v4.15.0 - 2026-01-01
**Security**
NB: **If your application relies on cross-origin or same-site (same subdomain) requests, do not blindly push this version to production.**
The CSRF middleware now supports the [**Sec-Fetch-Site**](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site) header as a modern, defense-in-depth approach to [CSRF
protection](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers), implementing the OWASP-recommended Fetch Metadata API alongside the traditional token-based mechanism.
**How it works:**
Modern browsers automatically send the `Sec-Fetch-Site` header with all requests, indicating the relationship
between the request origin and the target. The middleware uses this to make security decisions:
- **`same-origin`** or **`none`**: Requests are allowed (exact origin match or direct user navigation)
- **`same-site`**: Falls back to token validation (e.g., subdomain to main domain)
- **`cross-site`**: Blocked by default with 403 error for unsafe methods (POST, PUT, DELETE, PATCH)
For browsers that don't send this header (older browsers), the middleware seamlessly falls back to
traditional token-based CSRF protection.
**New Configuration Options:**
- `TrustedOrigins []string`: Allowlist specific origins for cross-site requests (useful for OAuth callbacks, webhooks)
- `AllowSecFetchSiteFunc func(echo.Context) (bool, error)`: Custom logic for same-site/cross-site request validation
**Example:**
```go
e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
// Allow OAuth callbacks from trusted provider
TrustedOrigins: []string{"https://oauth-provider.com"},
// Custom validation for same-site requests
AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
// Your custom authorization logic here
return validateCustomAuth(c), nil
// return true, err // blocks request with error
// return true, nil // allows CSRF request through
// return false, nil // falls back to legacy token logic
},
}))
```
PR: https://github.com/labstack/echo/pull/2858
**Type-Safe Generic Parameter Binding**
* Added generic functions for type-safe parameter extraction and context access by @aldas in https://github.com/labstack/echo/pull/2856
Echo now provides generic functions for extracting path, query, and form parameters with automatic type conversion,
eliminating manual string parsing and type assertions.
**New Functions:**
- Path parameters: `PathParam[T]`, `PathParamOr[T]`
- Query parameters: `QueryParam[T]`, `QueryParamOr[T]`, `QueryParams[T]`, `QueryParamsOr[T]`
- Form values: `FormParam[T]`, `FormParamOr[T]`, `FormParams[T]`, `FormParamsOr[T]`
- Context store: `ContextGet[T]`, `ContextGetOr[T]`
**Supported Types:**
Primitives (`bool`, `string`, `int`/`uint` variants, `float32`/`float64`), `time.Duration`, `time.Time`
(with custom layouts and Unix timestamp support), and custom types implementing `BindUnmarshaler`,
`TextUnmarshaler`, or `JSONUnmarshaler`.
**Example:**
```go
// Before: Manual parsing
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
// After: Type-safe with automatic parsing
id, err := echo.PathParam[int](c, "id")
// With default values
page, err := echo.QueryParamOr[int](c, "page", 1)
limit, err := echo.QueryParamOr[int](c, "limit", 20)
// Type-safe context access (no more panics from type assertions)
user, err := echo.ContextGet[*User](c, "user")
```
PR: https://github.com/labstack/echo/pull/2856
**DEPRECATION NOTICE** Timeout Middleware Deprecated - Use ContextTimeout Instead
The `middleware.Timeout` middleware has been **deprecated** due to fundamental architectural issues that cause
data races. Use `middleware.ContextTimeout` or `middleware.ContextTimeoutWithConfig` instead.
**Why is this being deprecated?**
The Timeout middleware manipulates response writers across goroutine boundaries, which causes data races that
cannot be reliably fixed without a complete architectural redesign. The middleware:
- Swaps the response writer using `http.TimeoutHandler`
- Must be the first middleware in the chain (fragile constraint)
- Can cause races with other middleware (Logger, metrics, custom middleware)
- Has been the source of multiple race condition fixes over the years
**What should you use instead?**
The `ContextTimeout` middleware (available since v4.12.0) provides timeout functionality using Go's standard
context mechanism. It is:
- Race-free by design
- Can be placed anywhere in the middleware chain
- Simpler and more maintainable
- Compatible with all other middleware
**Migration Guide:**
```go
// Before (deprecated):
e.Use(middleware.Timeout())
// After (recommended):
e.Use(middleware.ContextTimeout(30 * time.Second))
```
**Important Behavioral Differences:**
1. **Handler cooperation required**: With ContextTimeout, your handlers must observe the request context (e.g. select on `ctx.Done()`, or pass `ctx` to context-aware calls) for cooperative
cancellation. The old Timeout middleware would send a 503 response regardless of handler cooperation, but had
data race issues.
2. **Error handling**: ContextTimeout returns errors through the standard error handling flow. Handlers that receive
`context.DeadlineExceeded` should handle it appropriately:
```go
e.GET("/long-task", func(c echo.Context) error {
ctx := c.Request().Context()
// Example: database query with context
result, err := db.QueryContext(ctx, "SELECT * FROM large_table")
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
// Handle timeout
return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout")
}
return err
}
return c.JSON(http.StatusOK, result)
})
```
3. **Background tasks**: For long-running background tasks, use goroutines with context:
```go
e.GET("/async-task", func(c echo.Context) error {
ctx := c.Request().Context()
resultCh := make(chan Result, 1)
errCh := make(chan error, 1)
go func() {
result, err := performLongTask(ctx)
if err != nil {
errCh <- err
return
}
resultCh <- result
}()
select {
case result := <-resultCh:
return c.JSON(http.StatusOK, result)
case err := <-errCh:
return err
case <-ctx.Done():
return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout")
}
})
```
**Enhancements**
* Fixes by @aldas in https://github.com/labstack/echo/pull/2852
* Generic functions by @aldas in https://github.com/labstack/echo/pull/2856
* CSRF with Sec-Fetch-Site checks by @aldas in https://github.com/labstack/echo/pull/2858
## v4.14.0 - 2025-12-11
`middleware.Logger` has been deprecated. For request logging, use `middleware.RequestLogger` or
`middleware.RequestLoggerWithConfig`.
`middleware.RequestLogger` replaces `middleware.Logger`, offering comparable configuration while relying on the
Go standard library’s new `slog` logger.
The previous default output format was JSON. The new default follows the standard `slog` logger settings.
To continue emitting request logs in JSON, configure `slog` accordingly:
```go
slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))
e.Use(middleware.RequestLogger())
```
**Security**
* Logger middleware json string escaping and deprecation by @aldas in https://github.com/labstack/echo/pull/2849
**Enhancements**
* Update deps by @aldas in https://github.com/labstack/echo/pull/2807
* refactor to use reflect.TypeFor by @cuiweixie in https://github.com/labstack/echo/pull/2812
* Use Go 1.25 in CI by @aldas in https://github.com/labstack/echo/pull/2810
* Modernize context.go by replacing interface{} with any by @vishr in https://github.com/labstack/echo/pull/2822
* Fix typo in SetParamValues comment by @vishr in https://github.com/labstack/echo/pull/2828
* Fix typo in ContextTimeout middleware comment by @vishr in https://github.com/labstack/echo/pull/2827
* Improve BasicAuth middleware: use strings.Cut and RFC compliance by @vishr in https://github.com/labstack/echo/pull/2825
* Fix duplicate plus operator in router backtracking logic by @yuya-morimoto in https://github.com/labstack/echo/pull/2832
* Replace custom private IP range check with built-in net.IP.IsPrivate by @kumapower17 in https://github.com/labstack/echo/pull/2835
* Ensure proxy connection is closed in proxyRaw function(#2837) by @kumapower17 in https://github.com/labstack/echo/pull/2838
* Update deps by @aldas in https://github.com/labstack/echo/pull/2843
* Update golang.org/x/* deps by @aldas in https://github.com/labstack/echo/pull/2850
## v4.13.4 - 2025-05-22
**Enhancements**
* chore: fix some typos in comment by @zhuhaicity in https://github.com/labstack/echo/pull/2735
* CI: test with Go 1.24 by @aldas in https://github.com/labstack/echo/pull/2748
* Add support for TLS WebSocket proxy by @t-ibayashi-safie in https://github.com/labstack/echo/pull/2762
**Security**
* Update dependencies for [GO-2025-3487](https://pkg.go.dev/vuln/GO-2025-3487), [GO-2025-3503](https://pkg.go.dev/vuln/GO-2025-3503) and [GO-2025-3595](https://pkg.go.dev/vuln/GO-2025-3595) in https://github.com/labstack/echo/pull/2780
## v4.13.3 - 2024-12-19
**Security**
* Update golang.org/x/net dependency [GO-2024-3333](https://pkg.go.dev/vuln/GO-2024-3333) in https://github.com/labstack/echo/pull/2722
## v4.13.2 - 2024-12-12
**Security**
* Update dependencies (dependabot reports [GO-2024-3321](https://pkg.go.dev/vuln/GO-2024-3321)) in https://github.com/labstack/echo/pull/2721
## v4.13.1 - 2024-12-11
**Fixes**
* Fix BindBody ignoring `Transfer-Encoding: chunked` requests by @178inaba in https://github.com/labstack/echo/pull/2717
## v4.13.0 - 2024-12-04
**BREAKING CHANGE** JWT Middleware Removed from Core use [labstack/echo-jwt](https://github.com/labstack/echo-jwt) instead
The JWT middleware has been **removed from Echo core** due to another security vulnerability, [CVE-2024-51744](https://nvd.nist.gov/vuln/detail/CVE-2024-51744). For more details, refer to issue [#2699](https://github.com/labstack/echo/issues/2699). A drop-in replacement is available in the [labstack/echo-jwt](https://github.com/labstack/echo-jwt) repository.
**Important**: Direct assignments like `token := c.Get("user").(*jwt.Token)` will now cause a panic due to an invalid cast. Update your code accordingly: in your handlers, replace imports of `"github.com/golang-jwt/jwt"` with `"github.com/golang-jwt/jwt/v5"`, which is the version used by the new middleware.
Background:
The version of `golang-jwt/jwt` (v3.2.2) previously used in Echo core has been in an unmaintained state for some time. This is not the first vulnerability affecting this library; earlier issues were addressed in [PR #1946](https://github.com/labstack/echo/pull/1946).
JWT middleware was marked as deprecated in Echo core as of [v4.10.0](https://github.com/labstack/echo/releases/tag/v4.10.0) on 2022-12-27. If you did not notice that, consider leveraging tools like [Staticcheck](https://staticcheck.dev/) to catch such deprecations earlier in your dev/CI flow. For bonus points - check out [gosec](https://github.com/securego/gosec).
We sincerely apologize for any inconvenience caused by this change. While we strive to maintain backward compatibility within Echo core, recurring security issues with third-party dependencies have forced this decision.
**Enhancements**
* remove jwt middleware by @stevenwhitehead in https://github.com/labstack/echo/pull/2701
* optimization: struct alignment by @behnambm in https://github.com/labstack/echo/pull/2636
* bind: Maintain backwards compatibility for map[string]interface{} binding by @thesaltree in https://github.com/labstack/echo/pull/2656
* Add Go 1.23 to CI by @aldas in https://github.com/labstack/echo/pull/2675
* improve `MultipartForm` test by @martinyonatann in https://github.com/labstack/echo/pull/2682
* `bind` : add support of multipart multi files by @martinyonatann in https://github.com/labstack/echo/pull/2684
* Add TemplateRenderer struct to ease creating renderers for `html/template` and `text/template` packages. by @aldas in https://github.com/labstack/echo/pull/2690
* Refactor TestBasicAuth to utilize table-driven test format by @ErikOlson in https://github.com/labstack/echo/pull/2688
* Remove broken header by @aldas in https://github.com/labstack/echo/pull/2705
* fix(bind body): content-length can be -1 by @phamvinhdat in https://github.com/labstack/echo/pull/2710
* CORS middleware should compile allowOrigin regexp at creation by @aldas in https://github.com/labstack/echo/pull/2709
* Shorten Github issue template and add test example by @aldas in https://github.com/labstack/echo/pull/2711
## v4.12.0 - 2024-04-15
**Security**
* Update golang.org/x/net dep because of [GO-2024-2687](https://pkg.go.dev/vuln/GO-2024-2687) by @aldas in https://github.com/labstack/echo/pull/2625
**Enhancements**
* binder: make binding to Map work better with string destinations by @aldas in https://github.com/labstack/echo/pull/2554
* README.md: add Encore as sponsor by @marcuskohlberg in https://github.com/labstack/echo/pull/2579
* Reorder paragraphs in README.md by @aldas in https://github.com/labstack/echo/pull/2581
* CI: upgrade actions/checkout to v4 by @aldas in https://github.com/labstack/echo/pull/2584
* Remove default charset from 'application/json' Content-Type header by @doortts in https://github.com/labstack/echo/pull/2568
* CI: Use Go 1.22 by @aldas in https://github.com/labstack/echo/pull/2588
* binder: allow binding to a nil map by @georgmu in https://github.com/labstack/echo/pull/2574
* Add Skipper Unit Test In BasicBasicAuthConfig and Add More Detail Explanation regarding BasicAuthValidator by @RyoKusnadi in https://github.com/labstack/echo/pull/2461
* fix some typos by @teslaedison in https://github.com/labstack/echo/pull/2603
* fix: some typos by @pomadev in https://github.com/labstack/echo/pull/2596
* Allow ResponseWriters to unwrap writers when flushing/hijacking by @aldas in https://github.com/labstack/echo/pull/2595
* Add SPDX licence comments to files. by @aldas in https://github.com/labstack/echo/pull/2604
* Upgrade deps by @aldas in https://github.com/labstack/echo/pull/2605
* Change type definition blocks to single declarations. This helps copy… by @aldas in https://github.com/labstack/echo/pull/2606
* Fix Real IP logic by @cl-bvl in https://github.com/labstack/echo/pull/2550
* Default binder can use `UnmarshalParams(params []string) error` inter… by @aldas in https://github.com/labstack/echo/pull/2607
* Default binder can bind pointer to slice as struct field. For example `*[]string` by @aldas in https://github.com/labstack/echo/pull/2608
* Remove maxparam dependence from Context by @aldas in https://github.com/labstack/echo/pull/2611
* When route is registered with empty path it is normalized to `/`. by @aldas in https://github.com/labstack/echo/pull/2616
* proxy middleware should use httputil.ReverseProxy for SSE requests by @aldas in https://github.com/labstack/echo/pull/2624
## v4.11.4 - 2023-12-20
**Security**
* Upgrade golang.org/x/crypto to v0.17.0 to fix vulnerability [issue](https://pkg.go.dev/vuln/GO-2023-2402) [#2562](https://github.com/labstack/echo/pull/2562)
**Enhancements**
* Update deps and mark Go version to 1.18 as this is what golang.org/x/* use [#2563](https://github.com/labstack/echo/pull/2563)
* Request logger: add example for Slog https://pkg.go.dev/log/slog [#2543](https://github.com/labstack/echo/pull/2543)
## v4.11.3 - 2023-11-07
**Security**
* 'c.Attachment' and 'c.Inline' should escape filename in 'Content-Disposition' header to avoid 'Reflect File Download' vulnerability. [#2541](https://github.com/labstack/echo/pull/2541)
**Enhancements**
* Tests: refactor context tests to be separate functions [#2540](https://github.com/labstack/echo/pull/2540)
* Proxy middleware: reuse echo request context [#2537](https://github.com/labstack/echo/pull/2537)
* Mark unmarshallable yaml struct tags as ignored [#2536](https://github.com/labstack/echo/pull/2536)
## v4.11.2 - 2023-10-11
**Security**
* Bump golang.org/x/net to prevent CVE-2023-39325 / CVE-2023-44487 HTTP/2 Rapid Reset Attack [#2527](https://github.com/labstack/echo/pull/2527)
* fix(sec): randomString bias introduced by #2490 [#2492](https://github.com/labstack/echo/pull/2492)
* CSRF/RequestID mw: switch math/random usage to crypto/random [#2490](https://github.com/labstack/echo/pull/2490)
**Enhancements**
* Delete unused context in body_limit.go [#2483](https://github.com/labstack/echo/pull/2483)
* Use Go 1.21 in CI [#2505](https://github.com/labstack/echo/pull/2505)
* Fix some typos [#2511](https://github.com/labstack/echo/pull/2511)
* Allow CORS middleware to send Access-Control-Max-Age: 0 [#2518](https://github.com/labstack/echo/pull/2518)
* Bump dependencies [#2522](https://github.com/labstack/echo/pull/2522)
## v4.11.1 - 2023-07-16
**Fixes**
* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481)
## v4.11.0 - 2023-07-14
**Fixes**
* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409)
* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411)
* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456)
* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477)
**Enhancements**
* Added an optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410)
* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424)
* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425)
* Check whether error is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429)
* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416)
* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436)
* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444)
* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414)
* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452)
* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267)
* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475)
## v4.10.2 - 2023-02-22
**Security**
* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406)
* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow credentials less likely [#2405](https://github.com/labstack/echo/pull/2405)
**Enhancements**
* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277)
## v4.10.1 - 2023-02-19
**Security**
* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402)
**Enhancements**
* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377)
* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385)
* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380)
* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394)
## v4.10.0 - 2022-12-27
**Security**
* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead.
JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using
which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain.
* This minor version bumps minimum Go version to 1.17 (from 1.16) due to `golang.org/x/` packages we depend on. There are
several vulnerabilities fixed in these libraries.
Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
**Enhancements**
* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305)
* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336)
* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316)
* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338)
* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315)
* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329)
* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340)
* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342)
* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343)
* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345)
* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182)
* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350)
* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341)
* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162)
* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358)
* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362)
* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366)
* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337)
## v4.9.1 - 2022-10-12
**Fixes**
* Fix logger panicking (when template is set to empty) by bumping dependency version [#2295](https://github.com/labstack/echo/issues/2295)
**Enhancements**
* Improve CORS documentation [#2272](https://github.com/labstack/echo/pull/2272)
* Update readme about supported Go versions [#2291](https://github.com/labstack/echo/pull/2291)
* Tests: improve error handling on closing body [#2254](https://github.com/labstack/echo/pull/2254)
* Tests: refactor some of the assertions in tests [#2275](https://github.com/labstack/echo/pull/2275)
* Tests: refactor assertions [#2301](https://github.com/labstack/echo/pull/2301)
## v4.9.0 - 2022-09-04
**Security**
* Fix open redirect vulnerability in handlers serving static directories (e.Static, e.StaticFs, echo.StaticDirectoryHandler) [#2260](https://github.com/labstack/echo/pull/2260)
**Enhancements**
* Allow configuring ErrorHandler in CSRF middleware [#2257](https://github.com/labstack/echo/pull/2257)
* Replace HTTP method constants in tests with stdlib constants [#2247](https://github.com/labstack/echo/pull/2247)
## v4.8.0 - 2022-08-10
**Most notable things**
You can now add any arbitrary HTTP method type as a route [#2237](https://github.com/labstack/echo/pull/2237)
```go
e.Add("COPY", "/*", func(c echo.Context) error {
return c.String(http.StatusOK, "OK COPY")
})
```
You can add custom 404 handler for specific paths [#2217](https://github.com/labstack/echo/pull/2217)
```go
e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })
g := e.Group("/images")
g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })
```
**Enhancements**
* Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to Valuebinder [#2127](https://github.com/labstack/echo/pull/2127)
* Refactor: body_limit middleware unit test [#2145](https://github.com/labstack/echo/pull/2145)
* Refactor: Timeout mw: rework how test waits for timeout. [#2187](https://github.com/labstack/echo/pull/2187)
* BasicAuth middleware returns 500 InternalServerError on invalid base64 strings but should return 400 [#2191](https://github.com/labstack/echo/pull/2191)
* Refactor: duplicated findStaticChild process at findChildWithLabel [#2176](https://github.com/labstack/echo/pull/2176)
* Allow different param names in different methods with same path scheme [#2209](https://github.com/labstack/echo/pull/2209)
* Add support for registering handlers for different 404 routes [#2217](https://github.com/labstack/echo/pull/2217)
* Middlewares should use errors.As() instead of type assertion on HTTPError [#2227](https://github.com/labstack/echo/pull/2227)
* Allow arbitrary HTTP method types to be added as routes [#2237](https://github.com/labstack/echo/pull/2237)
## v4.7.2 - 2022-03-16
**Fixes**
* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131)
* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136)
* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126)
**Enhancements**
* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134)
## v4.7.1 - 2022-03-13
**Fixes**
* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123)
**Enhancements**
* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116)
## v4.7.0 - 2022-03-01
**Enhancements**
* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060)
* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072)
* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027)
* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064)
**Fixes**
* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007)
* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102)
**General**
* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103)
* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078)
* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049)
* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README
## v4.6.3 - 2022-01-10
**Fixes**
* Fixed Echo version number in greeting message which was not incremented to `4.6.2` [#2066](https://github.com/labstack/echo/issues/2066)
## v4.6.2 - 2022-01-08
**Fixes**
* Fixed route containing escaped colon should be matchable but is not matched to request path [#2047](https://github.com/labstack/echo/pull/2047)
* Fixed a problem that returned wrong content-encoding when the gzip compressed content was empty. [#1921](https://github.com/labstack/echo/pull/1921)
* Update (test) dependencies [#2021](https://github.com/labstack/echo/pull/2021)
**Enhancements**
* Add support for configurable target header for the request_id middleware [#2040](https://github.com/labstack/echo/pull/2040)
* Change decompress middleware to use stream decompression instead of buffering [#2018](https://github.com/labstack/echo/pull/2018)
* Documentation updates
## v4.6.1 - 2021-09-26
**Enhancements**
* Add start time to request logger middleware values [#1991](https://github.com/labstack/echo/pull/1991)
## v4.6.0 - 2021-09-20
Introduced a new [request logger](https://github.com/labstack/echo/blob/master/middleware/request_logger.go) middleware
to help with cases when you want to use some other logging library in your application.
**Fixes**
* fix timeout middleware warning: superfluous response.WriteHeader [#1905](https://github.com/labstack/echo/issues/1905)
**Enhancements**
* Add Cookie to KeyAuth middleware's KeyLookup [#1929](https://github.com/labstack/echo/pull/1929)
* JWT middleware should ignore case of auth scheme in request header [#1951](https://github.com/labstack/echo/pull/1951)
* Refactor default error handler to return first if response is already committed [#1956](https://github.com/labstack/echo/pull/1956)
* Added request logger middleware which helps to use custom logger library for logging requests. [#1980](https://github.com/labstack/echo/pull/1980)
* Allow escaping of colon in route path so Google Cloud API "custom methods" could be implemented [#1988](https://github.com/labstack/echo/pull/1988)
## v4.5.0 - 2021-08-01
**Important notes**
A **BREAKING CHANGE** is introduced for JWT middleware users.
The JWT library used for the JWT middleware had to be changed from [github.com/dgrijalva/jwt-go](https://github.com/dgrijalva/jwt-go) to
[github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) due to the former library being unmaintained and affected by security
issues.
The [github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) project is a drop-in replacement, but supports only the latest 2 Go versions.
So for JWT middleware users Go 1.15+ is required. For detailed information please read [#1940](https://github.com/labstack/echo/discussions/1940).
To change the library imports in all .go files in your project replace all occurrences of `dgrijalva/jwt-go` with `golang-jwt/jwt`.
For Linux CLI you can use:
```bash
find -type f -name "*.go" -exec sed -i "s/dgrijalva\/jwt-go/golang-jwt\/jwt/g" {} \;
go mod tidy
```
**Fixes**
* Change JWT library to `github.com/golang-jwt/jwt` [#1946](https://github.com/labstack/echo/pull/1946)
## v4.4.0 - 2021-07-12
**Fixes**
* Split HeaderXForwardedFor header only by comma [#1878](https://github.com/labstack/echo/pull/1878)
* Fix Timeout middleware Context propagation [#1910](https://github.com/labstack/echo/pull/1910)
**Enhancements**
* Bind data using headers as source [#1866](https://github.com/labstack/echo/pull/1866)
* Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing. [#1887](https://github.com/labstack/echo/pull/1887)
* Adding tests for Echo#Host [#1895](https://github.com/labstack/echo/pull/1895)
* Adds RequestIDHandler function to RequestID middleware [#1898](https://github.com/labstack/echo/pull/1898)
* Allow for custom JSON encoding implementations [#1880](https://github.com/labstack/echo/pull/1880)
## v4.3.0 - 2021-05-08
**Important notes**
* Route matching has improvements for following cases:
1. Correctly match routes with parameter part as last part of route (with trailing slash)
2. Considering handlers when resolving routes and search for matching http method handler
* Echo minimal Go version is now 1.13.
**Fixes**
* When url ends with slash first param route is the match [#1804](https://github.com/labstack/echo/pull/1812)
* Router should check if node is suitable as matching route by path+method and if not then continue search in tree [#1808](https://github.com/labstack/echo/issues/1808)
* Fix timeout middleware not writing response correctly when handler panics [#1864](https://github.com/labstack/echo/pull/1864)
* Fix binder not working with embedded pointer structs [#1861](https://github.com/labstack/echo/pull/1861)
* Add Go 1.16 to CI and drop 1.12 specific code [#1850](https://github.com/labstack/echo/pull/1850)
**Enhancements**
* Make KeyFunc public in JWT middleware [#1756](https://github.com/labstack/echo/pull/1756)
* Add support for optional filesystem to the static middleware [#1797](https://github.com/labstack/echo/pull/1797)
* Add a custom error handler to key-auth middleware [#1847](https://github.com/labstack/echo/pull/1847)
* Allow JWT token to be looked up from multiple sources [#1845](https://github.com/labstack/echo/pull/1845)
## v4.2.2 - 2021-04-07
**Fixes**
* Allow proxy middleware to use query part in rewrite (#1802)
* Fix timeout middleware not sending status code when handler returns an error (#1805)
* Fix Bind() when target is array/slice and path/query params complains bind target not being struct (#1835)
* Fix panic in redirect middleware on short host name (#1813)
* Fix timeout middleware docs (#1836)
## v4.2.1 - 2021-03-08
**Important notes**
Due to a data race the config parameters for the newly added timeout middleware required a change.
See the [docs](https://echo.labstack.com/middleware/timeout).
A performance regression has been fixed, even bringing better performance than before for some routing scenarios.
**Fixes**
* Fix performance regression caused by path escaping (#1777, #1798, #1799, aldas)
* Avoid context canceled errors (#1789, clwluvw)
* Improve router to use on stack backtracking (#1791, aldas, stffabi)
* Fix panic in timeout middleware not being recovered and causing an application crash (#1794, aldas)
* Fix Echo.Serve() not serving on HTTP port correctly when TLSListener is used (#1785, #1793, aldas)
* Apply go fmt (#1788, Le0tk0k)
* Use strings.EqualFold (#1790, rkilingr)
* Improve code quality (#1792, withshubh)
This release was made possible by our **contributors**:
aldas, clwluvw, lammel, Le0tk0k, maciej-jezierski, rkilingr, stffabi, withshubh
## v4.2.0 - 2021-02-11
**Important notes**
The behaviour for binding data has been reworked for compatibility with echo before v4.1.11 by
enforcing `explicit tagging` for processing parameters. This **may break** your code if you
expect combined handling of query/path/form params.
Please see the updated documentation for [request](https://echo.labstack.com/guide/request) and [binding](https://echo.labstack.com/guide/request)
The handling for rewrite rules has been slightly adjusted to expand `*` to a non-greedy `(.*?)` capture group. This is only relevant if multiple asterisks are used in your rules.
Please see [rewrite](https://echo.labstack.com/middleware/rewrite) and [proxy](https://echo.labstack.com/middleware/proxy) for details.
**Security**
* Fix directory traversal vulnerability for Windows (#1718, little-cui)
* Fix open redirect vulnerability with trailing slash (#1771,#1775 aldas,GeoffreyFrogeye)
**Enhancements**
* Add Echo#ListenerNetwork as configuration (#1667, pafuent)
* Add ability to change the status code using response beforeFuncs (#1706, RashadAnsari)
* Echo server startup to allow data race free access to listener address
* Binder: Restore pre v4.1.11 behaviour for c.Bind() to use query params only for GET or DELETE methods (#1727, aldas)
* Binder: Add separate methods to bind only query params, path params or request body (#1681, aldas)
* Binder: New fluent binder for query/path/form parameter binding (#1717, #1736, aldas)
* Router: Performance improvements for missed routes (#1689, pafuent)
* Router: Improve performance for Real-IP detection using IndexByte instead of Split (#1640, imxyb)
* Middleware: Support real regex rules for rewrite and proxy middleware (#1767)
* Middleware: New rate limiting middleware (#1724, iambenkay)
* Middleware: New timeout middleware implementation for go1.13+ (#1743, )
* Middleware: Allow regex pattern for CORS middleware (#1623, KlotzAndrew)
* Middleware: Add IgnoreBase parameter to static middleware (#1701, lnenad, iambenkay)
* Middleware: Add an optional custom function to CORS middleware to validate origin (#1651, curvegrid)
* Middleware: Support form fields in JWT middleware (#1704, rkfg)
* Middleware: Use sync.Pool for (de)compress middleware to improve performance (#1699, #1672, pafuent)
* Middleware: Add decompress middleware to support gzip compressed requests (#1687, arun0009)
* Middleware: Add ErrJWTInvalid for JWT middleware (#1627, juanbelieni)
* Middleware: Add SameSite mode for CSRF cookies to support iframes (#1524, pr0head)
**Fixes**
* Fix handling of special trailing slash case for partial prefix (#1741, stffabi)
* Fix handling of static routes with trailing slash (#1747)
* Fix Static files route not working (#1671, pwli0755, lammel)
* Fix use of caret(^) in regex for rewrite middleware (#1588, chotow)
* Fix Echo#Reverse for Any type routes (#1695, pafuent)
* Fix Router#Find panic with infinite loop (#1661, pafuent)
* Fix Router#Find panic fails on Param paths (#1659, pafuent)
* Fix DefaultHTTPErrorHandler with Debug=true (#1477, lammel)
* Fix incorrect CORS headers (#1669, ulasakdeniz)
* Fix proxy middleware rewritePath to use url with updated tests (#1630, arun0009)
* Fix rewritePath for proxy middleware to use escaped path in (#1628, arun0009)
* Remove useless defer (#1656, imxyb)
**General**
* New maintainers for Echo: Roland Lammel (@lammel) and Pablo Andres Fuente (@pafuent)
* Add GitHub action to compare benchmarks (#1702, pafuent)
* Binding query/path params and form fields to struct only works for explicit tags (#1729,#1734, aldas)
* Add support for Go 1.15 in CI (#1683, asahasrabuddhe)
* Add test for request id to remain unchanged if provided (#1719, iambenkay)
* Refactor echo instance listener access and startup to speed up testing (#1735, aldas)
* Refactor and improve various tests for binding and routing
* Run test workflow only for relevant changes (#1637, #1636, pofl)
* Update .travis.yml (#1662, santosh653)
* Update README.md with a recent framework benchmark (#1679, pafuent)
This release was made possible by **over 100 commits** from more than **20 contributors**:
asahasrabuddhe, aldas, AndrewKlotz, arun0009, chotow, curvegrid, iambenkay, imxyb,
juanbelieni, lammel, little-cui, lnenad, pafuent, pofl, pr0head, pwli, RashadAnsari,
rkfg, santosh653, segfiner, stffabi, ulasakdeniz
================================================
FILE: CLAUDE.md
================================================
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## About This Project
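Echo is a high performance, minimalist Go web framework. This is the main repository for Echo v5 (the current major version), which is available as a Go module at `github.com/labstack/echo/v5`. Echo v4 (`github.com/labstack/echo/v4`) continues to receive security and bug fixes until 2026-12-31.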
## Development Commands
The project uses a Makefile for common development tasks:
- `make check` - Run linting, vetting, and race condition tests (default target)
- `make init` - Install required linting tools (golint, staticcheck)
- `make lint` - Run staticcheck and golint
- `make vet` - Run go vet
- `make test` - Run short tests
- `make race` - Run tests with race detector
- `make benchmark` - Run benchmarks
Example commands for development:
```bash
# Setup development environment
make init
# Run all checks (lint, vet, race)
make check
# Run specific tests
go test ./middleware/...
go test -race ./...
# Run benchmarks
make benchmark
```
## Code Architecture
### Core Components
**Echo Instance (`echo.go`)**
- The `Echo` struct is the top-level framework instance
- Contains router, middleware stacks, and server configuration
- Not goroutine-safe for mutations after server start
**Context (`context.go`)**
- The `Context` struct (handlers receive it as `*echo.Context`) represents the HTTP request/response context
- Provides methods for request/response handling, path parameters, data binding
- Core abstraction for request processing
**Router (`router.go`)**
- Radix tree-based HTTP router with smart route prioritization
- Supports static routes, parameterized routes (`/users/:id`), and wildcard routes (`/static/*`)
- Each HTTP method has its own routing tree
**Middleware (`middleware/`)**
- Extensive middleware system with 50+ built-in middlewares
- Middleware can be applied at Echo, Group, or individual route level
- Common middleware: RequestLogger, Recover, CORS, Rate Limiting, etc. (JWT is provided by the separate `github.com/labstack/echo-jwt` repository)
### Key Patterns
**Middleware Chain**
- Pre-middleware runs before routing
- Regular middleware runs after routing but before handlers
- Middleware functions have signature `func(next echo.HandlerFunc) echo.HandlerFunc`
**Route Groups**
- Routes can be grouped with common prefixes and middleware
- Groups support nested sub-groups
- Defined in `group.go`
**Data Binding**
- Automatic binding of request data (JSON, XML, form) to Go structs
- Implemented in `binder.go` with support for custom binders
**Error Handling**
- Centralized error handling via `HTTPErrorHandler`
- Panic recovery with stack traces via the `Recover` middleware (must be registered, e.g. `e.Use(middleware.Recover())`)
## File Organization
- Root directory: Core Echo functionality (echo.go, context.go, router.go, etc.)
- `middleware/`: All built-in middleware implementations
- `echotest/`: Testing helpers (`context.go`, `reader.go`) and their test data
- `_fixture/`: Test data files
## Code Style
- Go code uses tabs for indentation (per .editorconfig)
- Follows standard Go conventions and formatting
- Uses gofmt, golint, and staticcheck for code quality
## Testing
- Standard Go testing with `testing` package
- Tests include unit tests, integration tests, and benchmarks
- Race condition testing is required (`make race`)
- Test files follow `*_test.go` naming convention
================================================
FILE: LICENSE
================================================
The MIT License (MIT)
Copyright (c) 2022 LabStack
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: Makefile
================================================
PKG := "github.com/labstack/echo"
PKG_LIST := $(shell go list ${PKG}/...)
.DEFAULT_GOAL := check
check: lint vet race ## Check project
init:
@go install golang.org/x/lint/golint@latest
@go install honnef.co/go/tools/cmd/staticcheck@latest
lint: ## Lint the files
@staticcheck ${PKG_LIST}
@golint -set_exit_status ${PKG_LIST}
vet: ## Vet the files
@go vet ${PKG_LIST}
test: ## Run tests
@go test -short ${PKG_LIST}
race: ## Run tests with data race detector
@go test -race ${PKG_LIST}
benchmark: ## Run benchmarks
@go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
goversion ?= "1.25"
test_version: ## Run tests inside Docker with given version (defaults to 1.25 oldest supported). Example: make test_version goversion=1.25
@docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
================================================
FILE: README.md
================================================
[](https://sourcegraph.com/github.com/labstack/echo?badge)
[](https://pkg.go.dev/github.com/labstack/echo/v4)
[](https://goreportcard.com/report/github.com/labstack/echo)
[](https://github.com/labstack/echo/actions)
[](https://codecov.io/gh/labstack/echo)
[](https://github.com/labstack/echo/discussions)
[](https://twitter.com/labstack)
[](https://raw.githubusercontent.com/labstack/echo/master/LICENSE)
## Echo
High performance, extensible, minimalist Go web framework.
* [Official website](https://echo.labstack.com)
* [Quick start](https://echo.labstack.com/docs/quick-start)
* [Middlewares](https://echo.labstack.com/docs/category/middleware)
Help and questions: [Github Discussions](https://github.com/labstack/echo/discussions)
### Feature Overview
- Optimized HTTP router which smartly prioritizes routes
- Build robust and scalable RESTful APIs
- Group APIs
- Extensible middleware framework
- Define middleware at root, group or route level
- Data binding for JSON, XML and form payload
- Handy functions to send a variety of HTTP responses
- Centralized HTTP error handling
- Template rendering with any template engine
- Define your format for the logger
- Highly customizable
- Automatic TLS via Let’s Encrypt
- HTTP/2 support
## Sponsors
Click [here](https://github.com/sponsors/labstack) for more information on sponsorship.
## [Guide](https://echo.labstack.com/guide)
### Supported Echo versions
- Latest major version of Echo is `v5` as of 2026-01-18.
- Until 2026-03-31, any critical issues requiring breaking API changes will be addressed, even if this violates
semantic versioning.
- See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on upgrading.
- If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before
upgrading.
- Echo `v4` is supported with **security** updates and **bug** fixes until **2026-12-31**
### Installation
```sh
# go get github.com/labstack/echo/{version}
go get github.com/labstack/echo/v5
```
The latest version of Echo supports the last four Go major [releases](https://go.dev/doc/devel/release) and might work
with older versions.
### Example
```go
package main
import (
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
"log/slog"
"net/http"
)
func main() {
// Echo instance
e := echo.New()
// Middleware
e.Use(middleware.RequestLogger()) // use the RequestLogger middleware with slog logger
e.Use(middleware.Recover()) // recover panics as errors for proper error handling
// Routes
e.GET("/", hello)
// Start server
if err := e.Start(":8080"); err != nil {
slog.Error("failed to start server", "error", err)
}
}
// Handler
func hello(c *echo.Context) error {
return c.String(http.StatusOK, "Hello, World!")
}
```
# Official middleware repositories
The following middleware is maintained by the Echo team.
| Repository | Description |
|------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [pprof](https://pkg.go.dev/net/http/pprof) middlewares |
| [github.com/labstack/echo-opentelemetry](https://github.com/labstack/echo-opentelemetry) | [OpenTelemetry](https://opentelemetry.io/) middleware for tracing and metrics |
| [github.com/labstack/echo-prometheus](https://github.com/labstack/echo-prometheus) | [Prometheus](https://github.com/prometheus/client_golang/) middleware for Echo |
# Third-party middleware repositories
Be careful when adding 3rd-party middleware. The Echo team does not have the time or manpower to guarantee the safety
and quality of the middlewares in this list.
| Repository | Description |
|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [oapi-codegen/oapi-codegen](https://github.com/oapi-codegen/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator |
| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber's [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/log/slog) logging library wrapper for Echo logger interface. |
| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
Please send a PR to add your own library here.
## Contribute
**Use issues for everything**
- For a small change, just send a PR.
- For bigger changes open an issue for discussion before sending a PR.
- PR should have:
- Test case
- Documentation
- Example (If it makes sense)
- You can also contribute by:
- Reporting issues
- Suggesting new features or enhancements
- Improve/fix documentation
## Credits
- [Vishal Rana](https://github.com/vishr) (Author)
- [Nitin Rana](https://github.com/nr17) (Consultant)
- [Roland Lammel](https://github.com/lammel) (Maintainer)
- [Martti T.](https://github.com/aldas) (Maintainer)
- [Pablo Andres Fuente](https://github.com/pafuent) (Maintainer)
- [Contributors](https://github.com/labstack/echo/graphs/contributors)
## License
[MIT](https://github.com/labstack/echo/blob/master/LICENSE)
================================================
FILE: SECURITY.md
================================================
# Security Policy
## Supported Versions
| Version | Supported |
|-----------|-------------------------------------|
| 5.x.x | :white_check_mark: |
| >= 4.15.x | :white_check_mark: until 2026.12.31 |
| < 4.15 | :x: |
## Reporting a Vulnerability
https://github.com/labstack/echo/security/advisories/new
or look up the maintainers' email addresses in the commit history and email them.
================================================
FILE: _fixture/_fixture/README.md
================================================
This directory is used for the static middleware test
================================================
FILE: _fixture/certs/README.md
================================================
To generate a valid certificate and private key use the following command:
```bash
# In OpenSSL ≥ 1.1.1
openssl req -x509 -newkey rsa:4096 -sha256 -days 9999 -nodes \
-keyout key.pem -out cert.pem -subj "/CN=localhost" \
-addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1"
```
To check a certificate use the following command:
```bash
openssl x509 -in cert.pem -text
```
================================================
FILE: _fixture/certs/cert.pem
================================================
-----BEGIN CERTIFICATE-----
MIIFODCCAyCgAwIBAgIUaTvDluaMf+VJgYHQ0HFTS3yuCHYwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIxMDIyNzIxMzQ0MVoXDTQ4MDcx
NDIxMzQ0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF
AAOCAg8AMIICCgKCAgEAnqyyAAnWFH2TH7Epj5yfZxYrBvizydZe1Wo/1WpGR2IK
QT+qIul5sEKX/ERqEOXsawSrL3fw9cuSM8Z2vD/57ZZdoSR7XIdVaMDEQenJ968a
HObu4D27uBQwIwrM5ELgnd+fC4gis64nIu+2GSfHumZXi7lLW7DbNm8oWkMqI6tY
2s2wx2hwGYNVJrwSn4WGnkzhQ5U5mkcsLELMx7GR0Qnv6P7sNGZVeqMU7awkcSpR
crKR1OUP7XCJkEq83WLHSx50+QZv7LiyDmGnujHevRbdSHlcFfHZtaufYat+qICe
S3XADwRQe/0VSsmja6u3DAHy7VmL8PNisAdkopQZrhiI9OvGrpGZffs9zn+s/jeX
N1bqVDihCMiEjqXMlHx2oj3AXrZTFxb7y7Ap9C07nf70lpxQWW9SjMYRF98JBiHF
eJbQkNVkmz6T8ielQbX0l46F2SGK98oyFCGNIAZBUdj5CcS1E6w/lk4t58/em0k7
3wFC5qg0g0wfIbNSmxljBNxnaBYUqyaaAJJhpaEoOebm4RYV58hQ0FbMfpnLnSh4
dYStsk6i1PumWoa7D45DTtxF3kH7TB3YOB5aWaNGAPQC1m4Qcd23YB5Rd/ABirSp
ux6/cFGosjSfJ/G+G0RhNUpmcbDJvFSOhD2WCuieVhCTAzp+VPIA9bSqD+InlT0C
AwEAAaOBgTB/MB0GA1UdDgQWBBQZyM//SvzYKokQZI/0MVGb6PkH+zAfBgNVHSME
GDAWgBQZyM//SvzYKokQZI/0MVGb6PkH+zAPBgNVHRMBAf8EBTADAQH/MCwGA1Ud
EQQlMCOCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG
9w0BAQsFAAOCAgEAKGAJQmQ/KLw8iMb5QsyxxAonVjJ1eDAhNM3GWdHpM0/GFamO
vVtATLQQldwDiZJvrsCQPEc8ctZ2Utvg/StLQ3+rZpsvt0+gcUlLJK61qguwYqb2
+T7VK5s7V/OyI/tsuboOW50Pka9vQHV+Z0aM06Yu+HNDAq/UTpEOb/3MQvZd6Ooy
PTpZtFb/+5jIQa1dIsfFWmpBxF0+wUd9GEkX3j7nekwoZfJ8Ze4GWYERZbOFpDAQ
rIHdthH5VJztnpQJmaKqzgIOF+Rurwlp5ecSC33xNNjDaYtuf/fiWnoKGhHVSBhT
61+0yxn3rTgh/Dsm95xY00rSX6lmcvI+kRNTUc8GGPz0ajBH6xyY7bNhfMjmnSW/
C/XTEDbTAhT7ndWC5vvzp7ZU0TvN+WY6A0f2kxSnnrEk6QRUvRtKkjAkmAFz8exi
ttBBW0I3E5HNIC5CYRimq/9z+3clM/P1KbNblwuC65bL+PZ+nzFnn5hFaK9eLPol
OwZQXv7IvAw8GfgLTrEUT7eBCQwe1IqesA7NTxF1BVwmNUb2XamvQZ7ly67QybRw
0uJq80XjpVjBWYTTQy1dsnC2OTKdqGsV9TVIDR+UGfIG9cxL70pEbiSH2AX+IDCy
i3kNIvpXgBliAyOjW6Hj1fv6dNfAat/hqEfnquWkfvcs3HNrG/InwpwNAUs=
-----END CERTIFICATE-----
================================================
FILE: _fixture/certs/key.pem
================================================
-----BEGIN PRIVATE KEY-----
MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCerLIACdYUfZMf
sSmPnJ9nFisG+LPJ1l7Vaj/VakZHYgpBP6oi6XmwQpf8RGoQ5exrBKsvd/D1y5Iz
xna8P/ntll2hJHtch1VowMRB6cn3rxoc5u7gPbu4FDAjCszkQuCd358LiCKzrici
77YZJ8e6ZleLuUtbsNs2byhaQyojq1jazbDHaHAZg1UmvBKfhYaeTOFDlTmaRyws
QszHsZHRCe/o/uw0ZlV6oxTtrCRxKlFyspHU5Q/tcImQSrzdYsdLHnT5Bm/suLIO
Yae6Md69Ft1IeVwV8dm1q59hq36ogJ5LdcAPBFB7/RVKyaNrq7cMAfLtWYvw82Kw
B2SilBmuGIj068aukZl9+z3Of6z+N5c3VupUOKEIyISOpcyUfHaiPcBetlMXFvvL
sCn0LTud/vSWnFBZb1KMxhEX3wkGIcV4ltCQ1WSbPpPyJ6VBtfSXjoXZIYr3yjIU
IY0gBkFR2PkJxLUTrD+WTi3nz96bSTvfAULmqDSDTB8hs1KbGWME3GdoFhSrJpoA
kmGloSg55ubhFhXnyFDQVsx+mcudKHh1hK2yTqLU+6ZahrsPjkNO3EXeQftMHdg4
HlpZo0YA9ALWbhBx3bdgHlF38AGKtKm7Hr9wUaiyNJ8n8b4bRGE1SmZxsMm8VI6E
PZYK6J5WEJMDOn5U8gD1tKoP4ieVPQIDAQABAoICAEHF2CsH6MOpofi7GT08cR7s
I33KTcxWngzc9ATk/qjMTO/rEf1Sxmx3zkR1n3nNtQhPcR5GG43nin0HwWQbKOCB
OeJ4GuKp/o9jiHbCEEQpQyvD1jUBofSV+bYs3e2ogy8t6OGA1tGgWPy0XMlkoff0
QEnczw3864FO5m0z9h2/Ax//r02ZTw5kUEG0KAwT709jEuVO0AfRhM/8CKKmSola
EyaDtSmrWbdyLlSuzJRUNFrVBno3UTjdM0iqkks6jN3ojBhFwNNhY/1uIXafAXNk
LOnD1JYMIHCb6X809VWnqvYgozIWWb5rlA3iM2mITmId1LLqMYX5fWj2R5LUzSek
H+XG+F9FIouTaL1ACoXr0zyeY5N5YJdyXYa1tThdW+axX9ZrnPgeiQrmxzKPIyb7
LLlVtNBQUg/t5tX80KyYjkNUu4j3oq/uBYPi0m//ovwMyi9bSbbyPT+cDXuXX5Bc
oY7wyn3evXX0c1R7vdJLZLkLu+ctVex/9hvMjeW/mMasDjLnqY7pF3Skct1SX5N2
U8YVU9bGvFpLEwM9lmi/T7bcv+zbmGPlfTsZiFrCsixPLn7sX7y5M4L8au8O0jh0
nHm/8rWVg1Qw0Hobg3tA8FjeMa8Sr2fYmkNLVKFzhuJLxknTJLaUbX5CymNqWP4H
OctvfSY0nSZ1eQpBkQaJAoIBAQDTb/NhYCfaJBLXHVMy/VYd7kWGZ+I87artcE/l
8u0pJ8XOP4kp0otFIumpHUFodysAeP6HrI79MuJB40fy91HzWZC+NrPufFFFuZ0z
Ld1o3Y5nAeoZmMlf1F12Oe3OQZy7nm9eNNkfeoVtKqDv4FhAqk+aoMor86HscKsR
C6HlZFdGc7kX0ylrQAXPq9KLhcvUU9oAUpbqTbhYK83IebRJgFDG45HkVo9SUHpF
dmCFSb91eZpRGpdfNLCuLiSu52TebayaUCnceeAt8SyeiChJ/TwWmRRDJS0QUv6h
s3Wdp+cx9ANoujA4XzAs8Fld5IZ4bcG5jjwD62/tJyWrCC5DAoIBAQDAHfHjrYCK
GHBrMj+MA7cK7fCJUn/iJLSLGgo2ANYF5oq9gaCwHCtKIyB9DN/KiY0JpJ6PWg+Q
9Difq23YXiJjNEBS5EFTu9UwWAr1RhSAegrfHxm0sDbcAx31NtDYvBsADCWQYmzc
KPfBshf5K4g/VCIj2VzC2CE6kNtdhqLU6AV2Pi1Tl1S82xWoAjHy91tDmlFQNWCj
B2ZnZ7tY9zuwDfeBBOVCPHICgl5Q4PrY1KEWEXiNxgbtkNmOPAsY9WSqgOsP9pWK
J924gdCCvovINzZtgRisxKth6Fkhra+VCsheg9SWvgR09Deo6CCoSwYxOSb0cjh2
oyX5Rb1kJ7Z/AoIBAQCX2iNVoBV/GcFeNXV3fXLH9ESCj0FwuNC1zp/TanDhyerK
gd8k5k2Xzcc66gP73vpHUJ6dGlVni4/r+ivGV9HHkF/f/LGlaiuEhBZel2YY1mZb
nIhg8dZOuNqW+mvMYlsKdHNPmW0GqpwBF0iWfu1jI+4gA7Kvdj6o7RIvH8eaVEJK
GvqoHcP1fvmteJ2yDtmhGMfMy4QPqtnmmS8l+CJ/V2SsMuyorXIpkBsAoFAZ6ilT
WY53CT4F5nWt4v39j7pl9SatfT1TV0SmOjvtb6Rf3zu0jyR6RMzkmHa/839ZRylI
OxPntzDCi7qxy7yjLmlVPJ6RgZGgzwqHrEHlX+65AoIBAQCEzu6d3x5B2N02LZli
eFr8MjqbI64GLiulEY5HgNJzZ8k3cjocJI0Ehj36VIEMaYRXSzbVkIO8SCgwsPiR
n5mUDNX+t441jV62Odbxcc3Qdw226rABieOSupDmKEu92GOt57e8FV5939BOVYhf
FunsJYQoViXbCEAIVYVgJSfBmNfVwuvgonfQyn8xErtm4/pyRGa71PqGGSKAj2Qi
/16CuVUFGtZFsLV76JW8wZqHdI4bTF6TW3cEmaLbwcRGL7W0bMSS13rO8/pBh3QW
PhUxhoGYt6rQHHEBkPa04nXDyZ10QRwgTSGVnBIyMK4KyTpxorm8OI2x7dzdcomX
iCCPAoIBAETwfr2JKPb/AzrKhhbZgU+sLVn3WH/nb68VheNEmGOzsqXaSHCR2NOq
/ow7bawjc8yUIhBRzokR4F/7jGolOmfdq0MYFb6/YokssKfv1ugxBhmvOxpZ6F6E
cERJ8Ex/ffQU053gLR/0ammddVuS1GR5I/jEdP0lJVh0xapoZNUlT5dWYCgo20hY
ZAmKpU+veyUn+5Li0pmm959vnLK5LJzEA5mpz3w1QPPtVwQs05dwmEV3CRAcCeeh
8sXp49WNCSW4I3BxuTZzRV845SGIFhZwgVV42PTp2LPKl2p6E7Bk8xpUCCvBpALp
QmA5yIMx+u2Jpr7fUsXEXEPTEhvjff0=
-----END PRIVATE KEY-----
================================================
FILE: _fixture/dist/private.txt
================================================
private file
================================================
FILE: _fixture/dist/public/assets/readme.md
================================================
readme in assets
================================================
FILE: _fixture/dist/public/assets/subfolder/subfolder.md
================================================
file inside subfolder
================================================
FILE: _fixture/dist/public/index.html
================================================
Hello from index
================================================
FILE: _fixture/dist/public/test.txt
================================================
test.txt contents
================================================
FILE: _fixture/folder/index.html
================================================
Echo
================================================
FILE: _fixture/index.html
================================================
Echo
================================================
FILE: bind.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"encoding"
"encoding/xml"
"errors"
"mime/multipart"
"net/http"
"reflect"
"strconv"
"strings"
"time"
)
// Binder is the interface that wraps the Bind method.
// Implementations populate target from the request carried by c (see DefaultBinder.Bind for the default behavior).
type Binder interface {
	Bind(c *Context, target any) error
}

// DefaultBinder is the default implementation of the Binder interface.
type DefaultBinder struct{}

// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
// Types that don't implement this, but do implement encoding.TextUnmarshaler
// will use that interface instead.
type BindUnmarshaler interface {
	// UnmarshalParam decodes and assigns a value from a single request parameter
	// (path, query, form or header value).
	UnmarshalParam(param string) error
}

// bindMultipleUnmarshaler is used by the binder to unmarshal all values of one request field at once into a type
// implementing this interface. For example, for a request with the query `?a=1&a=2&b=test`, the slice `["1", "2"]`
// is passed to the unmarshaler of the field bound to `a`. It is tried before BindUnmarshaler and
// encoding.TextUnmarshaler.
type bindMultipleUnmarshaler interface {
	UnmarshalParams(params []string) error
}
// BindPathValues binds the route's path parameter values to target.
// Only fields carrying an explicit `param` tag are populated. Any binding failure is returned as
// ErrBadRequest wrapping the underlying error.
func BindPathValues(c *Context, target any) error {
	byName := make(map[string][]string)
	for _, pv := range c.PathValues() {
		byName[pv.Name] = []string{pv.Value}
	}

	err := bindData(target, byName, "param", nil)
	if err != nil {
		return ErrBadRequest.Wrap(err)
	}
	return nil
}
// BindQueryParams binds the URL query parameters to target.
// Only fields carrying an explicit `query` tag are populated. Any binding failure is returned as
// ErrBadRequest wrapping the underlying error.
func BindQueryParams(c *Context, target any) error {
	query := c.QueryParams()
	err := bindData(target, query, "query", nil)
	if err == nil {
		return nil
	}
	return ErrBadRequest.Wrap(err)
}
// BindBody binds the request body to target. The decoder is chosen from the media type of the Content-Type header:
//   - MIMEApplicationJSON: the Echo instance's JSONSerializer; an *HTTPError it returns is passed through unchanged;
//   - MIMEApplicationXML, MIMETextXML: encoding/xml;
//   - MIMEApplicationForm: form values, bound via `form` tags;
//   - MIMEMultipartForm: multipart values and files, bound via `form` tags.
//
// A request with ContentLength == 0 is a no-op. Any other media type yields an HTTPError with status
// 415 (Unsupported Media Type). Other decoding/binding failures are returned as ErrBadRequest wrapping the cause.
//
// NB: when binding forms, note that this implementation uses standard library form parsing,
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm.
// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
func BindBody(c *Context, target any) (err error) {
	req := c.Request()
	if req.ContentLength == 0 {
		return nil
	}

	// Extract the media type the way `mime.ParseMediaType()` does: everything before the first ';', trimmed.
	mediaType, _, _ := strings.Cut(req.Header.Get(HeaderContentType), ";")
	mediaType = strings.TrimSpace(mediaType)

	switch mediaType {
	case MIMEApplicationJSON:
		err = c.Echo().JSONSerializer.Deserialize(c, target)
		if err == nil {
			return nil
		}
		// The serializer may already have chosen a status code; keep its HTTPError as-is.
		var httpErr *HTTPError
		if errors.As(err, &httpErr) {
			return err
		}
		return ErrBadRequest.Wrap(err)
	case MIMEApplicationXML, MIMETextXML:
		if err = xml.NewDecoder(req.Body).Decode(target); err != nil {
			return ErrBadRequest.Wrap(err)
		}
		return nil
	case MIMEApplicationForm:
		values, formErr := c.FormValues()
		if formErr != nil {
			return ErrBadRequest.Wrap(formErr)
		}
		if err = bindData(target, values, "form", nil); err != nil {
			return ErrBadRequest.Wrap(err)
		}
		return nil
	case MIMEMultipartForm:
		form, formErr := c.MultipartForm()
		if formErr != nil {
			return ErrBadRequest.Wrap(formErr)
		}
		if err = bindData(target, form.Value, "form", form.File); err != nil {
			return ErrBadRequest.Wrap(err)
		}
		return nil
	default:
		return &HTTPError{Code: http.StatusUnsupportedMediaType}
	}
}
// BindHeaders binds the request's HTTP headers to target.
// Only fields carrying an explicit `header` tag are populated. Any binding failure is returned as
// ErrBadRequest wrapping the underlying error.
func BindHeaders(c *Context, target any) error {
	headers := c.Request().Header
	err := bindData(target, headers, "header", nil)
	if err == nil {
		return nil
	}
	return ErrBadRequest.Wrap(err)
}
// Bind implements the `Binder#Bind` function.
// Binding is done in the following order: 1) path params; 2) query params (GET, DELETE and HEAD requests only);
// 3) request body. Each step COULD override values bound by a previous step. For single-source binding use
// BindPathValues, BindQueryParams, BindBody or BindHeaders directly.
func (b *DefaultBinder) Bind(c *Context, target any) error {
	if err := BindPathValues(c, target); err != nil {
		return err
	}

	// Query parameters are bound only for GET/DELETE/HEAD, to avoid precedence surprises between URL and body.
	// For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would otherwise be ambiguous.
	// This method check restores the pre-v4.1.11 behavior (see issue #1670).
	switch c.Request().Method {
	case http.MethodGet, http.MethodDelete, http.MethodHead:
		if err := BindQueryParams(c, target); err != nil {
			return err
		}
	}

	return BindBody(c, target)
}
// bindData binds values from data (and, for multipart forms, dataFiles) into destination, which is expected to be
// a pointer.
//
// Struct destinations: ONLY fields with an EXPLICIT `tag` struct tag (e.g. `query:"id"`) are bound. Untagged struct
// fields that do not implement BindUnmarshaler are searched recursively for tagged fields. Keys are matched exactly
// first and then case-insensitively.
//
// Map destinations with a string-kinded key are handled by bindMapData.
//
// For any other destination type nothing is bound for the "param", "query" and "header" tags (the data is probably
// meant for the request body), and an error is returned for other tags.
func bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error {
	if destination == nil || (len(data) == 0 && len(dataFiles) == 0) {
		return nil
	}
	hasFiles := len(dataFiles) > 0
	typ := reflect.TypeOf(destination).Elem()
	val := reflect.ValueOf(destination).Elem()

	// Support binding to limited map destinations (see bindMapData). You are better off binding to a struct, but
	// there are users who want this map feature. The sources of data for these cases are params, query, header and
	// form, as these sources produce string values (most of the time slices of strings).
	if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String {
		bindMapData(typ, val, data)
		return nil
	}

	// !struct
	if typ.Kind() != reflect.Struct {
		if tag == "param" || tag == "query" || tag == "header" {
			// incompatible type, data is probably to be found in the body
			return nil
		}
		return errors.New("binding element must be a struct")
	}

	for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields
		typeField := typ.Field(i)
		structField := val.Field(i)
		if typeField.Anonymous && structField.Kind() == reflect.Ptr {
			// Embedded pointer: bind into the pointed-to struct. A nil pointer yields an invalid Value,
			// which the CanSet check below skips.
			structField = structField.Elem()
		}
		if !structField.CanSet() {
			continue
		}
		structFieldKind := structField.Kind()
		inputFieldName := typeField.Tag.Get(tag)
		if typeField.Anonymous && structFieldKind == reflect.Struct && inputFieldName != "" {
			// if anonymous struct with query/param/form tags, report an error
			return errors.New("query/param/form tags are not allowed with anonymous struct field")
		}

		if inputFieldName == "" {
			// If the tag is empty, we inspect if the field is a struct that does not implement BindUnmarshaler and
			// try to bind data into it (it might contain fields with tags).
			// Structs that implement BindUnmarshaler are bound only when they have an explicit tag.
			if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
				if err := bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil {
					return err
				}
			}
			// does not have explicit tag and is not an ordinary struct - so move to next field
			continue
		}

		if hasFiles {
			if ok, err := isFieldMultipartFile(structField.Type()); err != nil {
				return err
			} else if ok {
				if ok := setMultipartFileHeaderTypes(structField, inputFieldName, dataFiles); ok {
					continue
				}
			}
		}

		inputValue, exists := data[inputFieldName]
		if !exists {
			// Go json.Unmarshal supports case-insensitive binding. However the
			// url params are bound case-sensitive which is inconsistent. To
			// fix this we must check all of the map values in a
			// case-insensitive search.
			for k, v := range data {
				if strings.EqualFold(k, inputFieldName) {
					inputValue = v
					exists = true
					break
				}
			}
		}

		if !exists {
			continue
		}

		// NOTE: algorithm here is not particularly sophisticated. It probably does not work with absurd types like `**[]*int`
		// but it is smart enough to handle niche cases like `*int`,`*[]string`,`[]*int` .

		// try unmarshalling first, in case we're dealing with an alias to an array type
		if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok {
			if err != nil {
				return err
			}
			continue
		}

		// Every remaining binding path reads inputValue[0]. A key that is present with no values (possible e.g.
		// for programmatically built headers) has nothing to bind, so skip it instead of panicking.
		if len(inputValue) == 0 {
			continue
		}

		formatTag := typeField.Tag.Get("format")
		if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField, formatTag); ok {
			if err != nil {
				return err
			}
			continue
		}

		// we could be dealing with pointer to slice `*[]string` so dereference it. There are weird OpenAPI generators
		// that could create struct fields like that.
		if structFieldKind == reflect.Pointer {
			structFieldKind = structField.Elem().Kind()
			structField = structField.Elem()
		}

		if structFieldKind == reflect.Slice {
			sliceOf := structField.Type().Elem().Kind()
			numElems := len(inputValue)
			slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
			for j := 0; j < numElems; j++ {
				if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
					return err
				}
			}
			structField.Set(slice)
			continue
		}

		if err := setWithProperType(structFieldKind, inputValue[0], structField); err != nil {
			return err
		}
	}
	return nil
}

// bindMapData binds data into val, a map of type typ whose key kind is reflect.String.
//
// Supported element types are:
//   - types convertible from []string (e.g. []string or a named []string type): all values of a key are stored;
//   - string-kinded types: only the first value of each key is stored;
//   - empty interfaces (e.g. any): only the first value, as a string, is stored (kept for backward compatibility).
//
// Keys and values are converted to the map's (possibly named) key and element types. Keys with no values are
// skipped for single-value element types. Maps with any other element type are left untouched.
func bindMapData(typ reflect.Type, val reflect.Value, data map[string][]string) {
	keyType := typ.Key()
	elemType := typ.Elem()
	elemKind := elemType.Kind()

	isElemInterface := elemKind == reflect.Interface && elemType.NumMethod() == 0
	isElemString := elemKind == reflect.String
	isElemSliceOfStrings := elemKind == reflect.Slice && reflect.TypeFor[[]string]().ConvertibleTo(elemType)
	if !(isElemSliceOfStrings || isElemString || isElemInterface) {
		return
	}

	if val.IsNil() {
		val.Set(reflect.MakeMap(typ))
	}
	for key, values := range data {
		var elem reflect.Value
		if isElemSliceOfStrings {
			elem = reflect.ValueOf(values)
		} else {
			if len(values) == 0 {
				continue
			}
			elem = reflect.ValueOf(values[0])
		}
		val.SetMapIndex(reflect.ValueOf(key).Convert(keyType), elem.Convert(elemType))
	}
}
// setWithProperType parses val according to valueKind and stores the result in structField.
//
// BindUnmarshaler and encoding.TextUnmarshaler implementations take precedence over kind-based parsing. This also
// applies to elements of slices of such types. No `format` tag is known at this level, so an empty one is passed.
// Pointers are followed (unmarshalInputToField allocates nil pointers first). Unsupported kinds produce an
// "unknown type" error.
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
	if handled, err := unmarshalInputToField(valueKind, val, structField, ""); handled {
		return err
	}

	// strconv bit size for sized numeric kinds; it stays 0 for plain int/uint, meaning "size of int".
	var bitSize int
	switch valueKind {
	case reflect.Int8, reflect.Uint8:
		bitSize = 8
	case reflect.Int16, reflect.Uint16:
		bitSize = 16
	case reflect.Int32, reflect.Uint32, reflect.Float32:
		bitSize = 32
	case reflect.Int64, reflect.Uint64, reflect.Float64:
		bitSize = 64
	}

	switch valueKind {
	case reflect.Ptr:
		pointee := structField.Elem()
		return setWithProperType(pointee.Kind(), val, pointee)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return setIntField(val, bitSize, structField)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return setUintField(val, bitSize, structField)
	case reflect.Bool:
		return setBoolField(val, structField)
	case reflect.Float32, reflect.Float64:
		return setFloatField(val, bitSize, structField)
	case reflect.String:
		structField.SetString(val)
		return nil
	default:
		return errors.New("unknown type")
	}
}
// unmarshalInputsToField hands all values to field when the field (or, for pointer kinds, its pointee) implements
// bindMultipleUnmarshaler. It reports whether that interface was found, together with the unmarshaler's error.
// For pointer kinds a nil pointer is allocated first, even if the pointee turns out not to be an unmarshaler.
func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) {
	target := field
	if valueKind == reflect.Ptr {
		if target.IsNil() {
			target.Set(reflect.New(target.Type().Elem()))
		}
		target = target.Elem()
	}

	if u, isMulti := target.Addr().Interface().(bindMultipleUnmarshaler); isMulti {
		return true, u.UnmarshalParams(values)
	}
	return false, nil
}
// unmarshalInputToField tries to decode the single value val into field (or, for pointer kinds, its pointee, which
// is allocated first when nil). It reports whether a custom decoding path handled the value, together with its
// error. The paths are tried in order:
//   - a non-empty formatTag on a time.Time field: val is parsed with time.Parse using formatTag as the layout;
//   - BindUnmarshaler;
//   - encoding.TextUnmarshaler.
func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value, formatTag string) (bool, error) {
	target := field
	if valueKind == reflect.Ptr {
		if target.IsNil() {
			target.Set(reflect.New(target.Type().Elem()))
		}
		target = target.Elem()
	}
	ptr := target.Addr().Interface()

	if _, isTime := ptr.(*time.Time); isTime && formatTag != "" {
		parsed, err := time.Parse(formatTag, val)
		if err != nil {
			return true, err
		}
		target.Set(reflect.ValueOf(parsed))
		return true, nil
	}

	if u, ok := ptr.(BindUnmarshaler); ok {
		return true, u.UnmarshalParam(val)
	}
	if u, ok := ptr.(encoding.TextUnmarshaler); ok {
		return true, u.UnmarshalText([]byte(val))
	}
	return false, nil
}
// setIntField parses value as a base-10 signed integer of the given bitSize and stores it
// in field. An empty value is treated as 0. On a parse error the field is left unchanged
// and the strconv error is returned.
func setIntField(value string, bitSize int, field reflect.Value) error {
	if value == "" {
		field.SetInt(0)
		return nil
	}
	n, err := strconv.ParseInt(value, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetInt(n)
	return nil
}
// setUintField parses value as a base-10 unsigned integer of the given bitSize and stores
// it in field. An empty value is treated as 0. On a parse error the field is left
// unchanged and the strconv error is returned.
func setUintField(value string, bitSize int, field reflect.Value) error {
	if value == "" {
		field.SetUint(0)
		return nil
	}
	n, err := strconv.ParseUint(value, 10, bitSize)
	if err != nil {
		return err
	}
	field.SetUint(n)
	return nil
}
// setBoolField parses value with strconv.ParseBool and stores it in field. An empty value
// is treated as false. On a parse error the field is left unchanged and the error is
// returned.
func setBoolField(value string, field reflect.Value) error {
	if value == "" {
		field.SetBool(false)
		return nil
	}
	b, err := strconv.ParseBool(value)
	if err != nil {
		return err
	}
	field.SetBool(b)
	return nil
}
// setFloatField parses value as a floating-point number of the given bitSize and stores it
// in field. An empty value is treated as 0. On a parse error the field is left unchanged
// and the strconv error is returned.
func setFloatField(value string, bitSize int, field reflect.Value) error {
	if value == "" {
		field.SetFloat(0)
		return nil
	}
	f, err := strconv.ParseFloat(value, bitSize)
	if err != nil {
		return err
	}
	field.SetFloat(f)
	return nil
}
var (
// NOT supported by bind as you can NOT check easily empty struct being actual file or not
multipartFileHeaderType = reflect.TypeFor[multipart.FileHeader]()
// supported by bind as you can check by nil value if file existed or not
multipartFileHeaderPointerType = reflect.TypeFor[*multipart.FileHeader]()
multipartFileHeaderSliceType = reflect.TypeFor[[]multipart.FileHeader]()
multipartFileHeaderPointerSliceType = reflect.TypeFor[[]*multipart.FileHeader]()
)
func isFieldMultipartFile(field reflect.Type) (bool, error) {
switch field {
case multipartFileHeaderPointerType,
multipartFileHeaderSliceType,
multipartFileHeaderPointerSliceType:
return true, nil
case multipartFileHeaderType:
return true, errors.New("binding to multipart.FileHeader struct is not supported, use pointer to struct")
default:
return false, nil
}
}
// setMultipartFileHeaderTypes binds the uploaded files named inputFieldName into structField.
//
// It returns false when no files were uploaded under that name, or when structField is not
// one of the supported file header types. Otherwise it returns true after the assignment:
//   - []*multipart.FileHeader receives the slice from files as-is (it shares that slice);
//   - []multipart.FileHeader receives copies of every header;
//   - *multipart.FileHeader receives the first header only.
func setMultipartFileHeaderTypes(structField reflect.Value, inputFieldName string, files map[string][]*multipart.FileHeader) bool {
	uploaded := files[inputFieldName]
	if len(uploaded) == 0 {
		return false
	}

	switch structField.Type() {
	case multipartFileHeaderPointerSliceType:
		structField.Set(reflect.ValueOf(uploaded))
		return true
	case multipartFileHeaderSliceType:
		copies := make([]multipart.FileHeader, 0, len(uploaded))
		for _, fh := range uploaded {
			copies = append(copies, *fh)
		}
		structField.Set(reflect.ValueOf(copies))
		return true
	case multipartFileHeaderPointerType:
		structField.Set(reflect.ValueOf(uploaded[0]))
		return true
	}
	return false
}
================================================
FILE: bind_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"bytes"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// bindTestStruct is an untagged bind target covering every scalar kind, pointer-to-scalar,
// time types (via Timestamp and time.Time), a BindUnmarshaler slice type (StringArray), an
// unexported field (cantSet) and a field with no matching input (DoesntExist).
//
// NOTE: the field sequence must stay identical to bindTestStructWithTags, because
// BenchmarkBindbindDataWithTags converts *bindTestStructWithTags to *bindTestStruct.
// The ordering presumably comes from a field-alignment pass; do not reorder casually.
type bindTestStruct struct {
	T           Timestamp
	GoT         time.Time
	PtrI16      *int16
	PtrUI       *uint
	Tptr        *Timestamp
	PtrF32      *float32
	PtrB        *bool
	PtrI32      *int32
	GoTptr      *time.Time
	PtrI64      *int64
	PtrI        *int
	PtrI8       *int8
	PtrF64      *float64
	PtrUI8      *uint8
	PtrUI64     *uint64
	PtrUI16     *uint16
	PtrS        *string
	PtrUI32     *uint32
	S           string
	cantSet     string // unexported: binding must never set it
	DoesntExist string // has no entry in the values fixture
	SA          StringArray
	F64         float64
	I           int
	UI64        uint64
	UI          uint
	I64         int64
	F32         float32
	UI32        uint32
	I32         int32
	UI16        uint16
	I16         int16
	B           bool
	UI8         uint8
	I8          int8
}
// bindTestStructWithTags mirrors bindTestStruct field-for-field but adds json/form tags
// equal to the field names, so tag-driven binding can be exercised.
//
// NOTE: keep the field sequence identical to bindTestStruct; BenchmarkBindbindDataWithTags
// relies on converting *bindTestStructWithTags to *bindTestStruct.
type bindTestStructWithTags struct {
	T           Timestamp  `json:"T" form:"T"`
	GoT         time.Time  `json:"GoT" form:"GoT"`
	PtrI16      *int16     `json:"PtrI16" form:"PtrI16"`
	PtrUI       *uint      `json:"PtrUI" form:"PtrUI"`
	Tptr        *Timestamp `json:"Tptr" form:"Tptr"`
	PtrF32      *float32   `json:"PtrF32" form:"PtrF32"`
	PtrB        *bool      `json:"PtrB" form:"PtrB"`
	PtrI32      *int32     `json:"PtrI32" form:"PtrI32"`
	GoTptr      *time.Time `json:"GoTptr" form:"GoTptr"`
	PtrI64      *int64     `json:"PtrI64" form:"PtrI64"`
	PtrI        *int       `json:"PtrI" form:"PtrI"`
	PtrI8       *int8      `json:"PtrI8" form:"PtrI8"`
	PtrF64      *float64   `json:"PtrF64" form:"PtrF64"`
	PtrUI8      *uint8     `json:"PtrUI8" form:"PtrUI8"`
	PtrUI64     *uint64    `json:"PtrUI64" form:"PtrUI64"`
	PtrUI16     *uint16    `json:"PtrUI16" form:"PtrUI16"`
	PtrS        *string    `json:"PtrS" form:"PtrS"`
	PtrUI32     *uint32    `json:"PtrUI32" form:"PtrUI32"`
	S           string     `json:"S" form:"S"`
	cantSet     string     // unexported and untagged: binding must never set it
	DoesntExist string      `json:"DoesntExist" form:"DoesntExist"`
	SA          StringArray `json:"SA" form:"SA"`
	F64         float64     `json:"F64" form:"F64"`
	I           int         `json:"I" form:"I"`
	UI64        uint64      `json:"UI64" form:"UI64"`
	UI          uint        `json:"UI" form:"UI"`
	I64         int64       `json:"I64" form:"I64"`
	F32         float32     `json:"F32" form:"F32"`
	UI32        uint32      `json:"UI32" form:"UI32"`
	I32         int32       `json:"I32" form:"I32"`
	UI16        uint16      `json:"UI16" form:"UI16"`
	I16         int16       `json:"I16" form:"I16"`
	B           bool        `json:"B" form:"B"`
	UI8         uint8       `json:"UI8" form:"UI8"`
	I8          int8        `json:"I8" form:"I8"`
}
// Helper types used as bind targets in the binder tests.
type (
	// Timestamp is a named time.Time. Its UnmarshalParam method parses RFC 3339 input.
	Timestamp time.Time
	// TA is a slice of Timestamp values.
	TA []Timestamp
	// StringArray splits one comma-separated parameter into several strings via UnmarshalParam.
	StringArray []string
	// Struct is a single-field struct with a custom UnmarshalParam method.
	Struct struct {
		Foo string
	}
	// Bar carries a json/query-tagged field. The tests use it as an embedded (anonymous) field.
	Bar struct {
		Baz int `json:"baz" query:"baz"`
	}
)
// UnmarshalParam parses src as an RFC 3339 timestamp. The receiver is always overwritten;
// after a parse error it holds the zero time, and the error is returned.
func (t *Timestamp) UnmarshalParam(src string) error {
	parsed, parseErr := time.Parse(time.RFC3339, src)
	*t = Timestamp(parsed)
	return parseErr
}
// UnmarshalParam splits src on commas and replaces the receiver with the pieces.
// It never fails.
func (a *StringArray) UnmarshalParam(src string) error {
	parts := strings.Split(src, ",")
	*a = parts
	return nil
}
// UnmarshalParam replaces the receiver with a Struct whose Foo is src. It never fails.
func (s *Struct) UnmarshalParam(src string) error {
	var fresh Struct
	fresh.Foo = src
	*s = fresh
	return nil
}
// GetCantSet exposes the unexported cantSet field so tests can confirm binding left it empty.
func (s bindTestStruct) GetCantSet() string { return s.cantSet }
// values is the shared input fixture for binding into bindTestStruct and
// bindTestStructWithTags. Keys are the Go field names. Each plain field and its Ptr*
// counterpart share the same input. Note that I/UI (and PtrI/PtrUI) are given "0".
// "cantSet" targets an unexported field that must stay empty. "ST" has no matching field
// in bindTestStruct.
var values = map[string][]string{
	"I":       {"0"},
	"PtrI":    {"0"},
	"I8":      {"8"},
	"PtrI8":   {"8"},
	"I16":     {"16"},
	"PtrI16":  {"16"},
	"I32":     {"32"},
	"PtrI32":  {"32"},
	"I64":     {"64"},
	"PtrI64":  {"64"},
	"UI":      {"0"},
	"PtrUI":   {"0"},
	"UI8":     {"8"},
	"PtrUI8":  {"8"},
	"UI16":    {"16"},
	"PtrUI16": {"16"},
	"UI32":    {"32"},
	"PtrUI32": {"32"},
	"UI64":    {"64"},
	"PtrUI64": {"64"},
	"B":       {"true"},
	"PtrB":    {"true"},
	"F32":     {"32.5"},
	"PtrF32":  {"32.5"},
	"F64":     {"64.5"},
	"PtrF64":  {"64.5"},
	"S":       {"test"},
	"PtrS":    {"test"},
	"cantSet": {"test"},
	// RFC 3339 timestamps, bound via Timestamp.UnmarshalParam or time.Time's text unmarshaling.
	"T":      {"2016-12-06T19:09:05+01:00"},
	"Tptr":   {"2016-12-06T19:09:05+01:00"},
	"GoT":    {"2016-12-06T19:09:05+01:00"},
	"GoTptr": {"2016-12-06T19:09:05+01:00"},
	"ST":     {"bar"},
}
// ptr returns a pointer to a copy of value. Go cannot take the address of a conversion
// or literal directly, so e.g. `v := []*int8{&int8(1)}` does not compile; use ptr(int8(1)).
func ptr[T any](value T) *T {
	p := new(T)
	*p = value
	return p
}
// TestToMultipleFields checks that one query parameter fills every field whose query tag
// matches it. This includes a tagged field inside a nested untagged struct. A nested field
// without a tag stays unbound.
func TestToMultipleFields(t *testing.T) {
	type Root struct {
		ID     int64 `query:"id"`
		Child2 struct {
			ID int64
		}
		Child1 struct {
			ID int64 `query:"id"`
		}
	}

	e := New()
	c := e.NewContext(
		httptest.NewRequest(http.MethodGet, "/?id=1&ID=2", nil),
		httptest.NewRecorder(),
	)

	var got Root
	if !assert.NoError(t, c.Bind(&got)) {
		return
	}
	assert.Equal(t, int64(1), got.ID)
	assert.Equal(t, int64(1), got.Child1.ID) // untagged parent, tagged child: bound by tag
	assert.Equal(t, int64(0), got.Child2.ID) // untagged parent, untagged child: not bound
}
func TestBindJSON(t *testing.T) {
testBindOkay(t, strings.NewReader(userJSON), nil, MIMEApplicationJSON)
testBindOkay(t, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON)
testBindArrayOkay(t, strings.NewReader(usersJSON), nil, MIMEApplicationJSON)
testBindArrayOkay(t, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON)
testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
testBindError(t, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{})
}
func TestBindXML(t *testing.T) {
testBindOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML)
testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML)
testBindArrayOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML)
testBindArrayOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML)
testBindError(t, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New(""))
testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{})
testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{})
testBindOkay(t, strings.NewReader(userXML), nil, MIMETextXML)
testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMETextXML)
testBindError(t, strings.NewReader(invalidContent), MIMETextXML, errors.New(""))
testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{})
testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{})
}
func TestBindForm(t *testing.T) {
testBindOkay(t, strings.NewReader(userForm), nil, MIMEApplicationForm)
testBindOkay(t, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm)
e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(HeaderContentType, MIMEApplicationForm)
err := c.Bind(&[]struct{ Field string }{})
assert.Error(t, err)
}
// TestBindQueryParams binds plain query parameters (including a '+'-encoded space) into user.
func TestBindQueryParams(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
	c := New().NewContext(req, httptest.NewRecorder())

	var u user
	if !assert.NoError(t, c.Bind(&u)) {
		return
	}
	assert.Equal(t, 1, u.ID)
	assert.Equal(t, "Jon Snow", u.Name)
}
// TestBindQueryParamsCaseInsensitive shows that upper-case query keys (ID, NAME) still bind
// to the user fields when no exact-case key is present.
func TestBindQueryParamsCaseInsensitive(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/?ID=1&NAME=Jon+Snow", nil)
	c := New().NewContext(req, httptest.NewRecorder())

	var u user
	if !assert.NoError(t, c.Bind(&u)) {
		return
	}
	assert.Equal(t, 1, u.ID)
	assert.Equal(t, "Jon Snow", u.Name)
}
// TestBindQueryParamsCaseSensitivePrioritized shows that, when both spellings are present,
// the exact-case key wins regardless of its position in the query string. It presumably
// matches user's lower-case `id`/`name` tags.
func TestBindQueryParamsCaseSensitivePrioritized(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/?id=1&ID=2&NAME=Jon+Snow&name=Jon+Doe", nil)
	c := New().NewContext(req, httptest.NewRecorder())

	var u user
	if !assert.NoError(t, c.Bind(&u)) {
		return
	}
	assert.Equal(t, 1, u.ID)
	assert.Equal(t, "Jon Doe", u.Name)
}
// TestBindHeaderParam binds request headers into user via BindHeaders.
func TestBindHeaderParam(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	for key, value := range map[string]string{"Name": "Jon Doe", "Id": "2"} {
		req.Header.Set(key, value)
	}
	c := New().NewContext(req, httptest.NewRecorder())

	var u user
	if !assert.NoError(t, BindHeaders(c, &u)) {
		return
	}
	assert.Equal(t, 2, u.ID)
	assert.Equal(t, "Jon Doe", u.Name)
}
// TestBindHeaderParamBadType checks that a header value that cannot be converted to the
// target field type (a non-numeric Id) is reported as an *HTTPError with status 400.
//
// errors.As is used instead of a bare type assertion, so the check still finds the
// *HTTPError if it is ever wrapped.
func TestBindHeaderParamBadType(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Id", "salamander")
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	u := new(user)
	err := BindHeaders(c, u)
	if !assert.Error(t, err) {
		return
	}

	var httpErr *HTTPError
	if assert.True(t, errors.As(err, &httpErr), "expected *HTTPError, got %T", err) {
		assert.Equal(t, http.StatusBadRequest, httpErr.Code)
	}
}
// TestBindUnmarshalParam checks that query values go through BindUnmarshaler
// implementations (Timestamp, []Timestamp, StringArray). It also checks that a nested
// struct is bound only through fields carrying a matching query tag.
func TestBindUnmarshalParam(t *testing.T) {
	const target = "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz"
	c := New().NewContext(httptest.NewRequest(http.MethodGet, target, nil), httptest.NewRecorder())

	var result struct {
		T         Timestamp `query:"ts"`
		ST        Struct
		StWithTag struct {
			Foo string `query:"st"`
		}
		TA []Timestamp `query:"ta"`
		SA StringArray `query:"sa"`
	}
	want := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))

	if !assert.NoError(t, c.Bind(&result)) {
		return
	}
	assert.Equal(t, want, result.T)
	assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
	assert.Equal(t, []Timestamp{want, want}, result.TA)
	assert.Equal(t, Struct{""}, result.ST)                // no field in the child struct has a matching tag
	assert.Equal(t, "baz", result.StWithTag.Foo) // child field tagged "st" matches the ST key
}
// TestBindUnmarshalText checks that time.Time fields (single and slice) are bound from
// RFC 3339 query values and StringArray through its UnmarshalParam. A nested struct whose
// field has no tag stays untouched.
func TestBindUnmarshalText(t *testing.T) {
	const target = "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz"
	c := New().NewContext(httptest.NewRequest(http.MethodGet, target, nil), httptest.NewRecorder())

	var result struct {
		T  time.Time   `query:"ts"`
		ST Struct
		TA []time.Time `query:"ta"`
		SA StringArray `query:"sa"`
	}
	want := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)

	if !assert.NoError(t, c.Bind(&result)) {
		return
	}
	assert.Equal(t, want, result.T)
	assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
	assert.Equal(t, []time.Time{want, want}, result.TA)
	assert.Equal(t, Struct{""}, result.ST) // the child struct's field has no tag
}
// TestBindUnmarshalParamPtr checks that a nil *Timestamp field is allocated and filled via
// Timestamp's UnmarshalParam.
func TestBindUnmarshalParamPtr(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
	c := New().NewContext(req, httptest.NewRecorder())

	var result struct {
		Tptr *Timestamp `query:"ts"`
	}
	if !assert.NoError(t, c.Bind(&result)) {
		return
	}
	want := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
	assert.Equal(t, want, *result.Tptr)
}
// TestBindUnmarshalParamAnonymousFieldPtr checks that the fields of an embedded, already
// allocated *Bar are bound from the query.
func TestBindUnmarshalParamAnonymousFieldPtr(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil)
	c := New().NewContext(req, httptest.NewRecorder())

	result := struct {
		*Bar
	}{Bar: new(Bar)}
	if assert.NoError(t, c.Bind(&result)) {
		assert.Equal(t, 1, result.Baz)
	}
}
// TestBindUnmarshalParamAnonymousFieldPtrNil checks that a nil embedded *Bar is left nil,
// without error, even when a matching query parameter is present.
func TestBindUnmarshalParamAnonymousFieldPtrNil(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil)
	c := New().NewContext(req, httptest.NewRecorder())

	var result struct {
		*Bar
	}
	if assert.NoError(t, c.Bind(&result)) {
		assert.Nil(t, result.Bar)
	}
}
// TestBindUnmarshalParamAnonymousFieldPtrCustomTag checks that putting a query tag on an
// embedded (anonymous) struct field is rejected with a descriptive error.
//
// The error is asserted to be non-nil before calling err.Error(); otherwise a regression
// that made Bind succeed would crash the test with a nil-pointer panic instead of
// failing cleanly.
func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodGet, `/?bar={"baz":100}&baz=1`, nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	result := struct {
		*Bar `json:"bar" query:"bar"`
	}{&Bar{}}

	err := c.Bind(&result)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "query/param/form tags are not allowed with anonymous struct field")
	}
}
// TestBindUnmarshalTextPtr checks that a nil *time.Time field is allocated and parsed from
// an RFC 3339 query value.
func TestBindUnmarshalTextPtr(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
	c := New().NewContext(req, httptest.NewRecorder())

	var result struct {
		Tptr *time.Time `query:"ts"`
	}
	if !assert.NoError(t, c.Bind(&result)) {
		return
	}
	assert.Equal(t, time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC), *result.Tptr)
}
func TestBindMultipartForm(t *testing.T) {
bodyBuffer := new(bytes.Buffer)
mw := multipart.NewWriter(bodyBuffer)
mw.WriteField("id", "1")
mw.WriteField("name", "Jon Snow")
mw.Close()
body := bodyBuffer.Bytes()
testBindOkay(t, bytes.NewReader(body), nil, mw.FormDataContentType())
testBindOkay(t, bytes.NewReader(body), dummyQuery, mw.FormDataContentType())
}
func TestBindUnsupportedMediaType(t *testing.T) {
testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
}
// TestDefaultBinder_bindDataToMap documents how bindData fills map destinations:
//   - map[string]string and map[string]any receive the first value of each key;
//   - map[string][]string receives all values;
//   - other element types (int, []int) are skipped;
//   - a nil destination map is allocated only when something is actually bound.
func TestDefaultBinder_bindDataToMap(t *testing.T) {
	exampleData := map[string][]string{
		"multiple": {"1", "2"},
		"single":   {"3"},
	}

	// dest and expect are both pointers to maps. assert.Equal compares the pointees deeply,
	// which keeps the difference between a nil map and an empty non-nil map.
	testCases := []struct {
		name   string
		dest   any
		expect any
	}{
		{
			name:   "ok, bind to map[string]string",
			dest:   &map[string]string{},
			expect: &map[string]string{"multiple": "1", "single": "3"},
		},
		{
			name:   "ok, bind to map[string]string with nil map",
			dest:   new(map[string]string),
			expect: &map[string]string{"multiple": "1", "single": "3"},
		},
		{
			name:   "ok, bind to map[string][]string",
			dest:   &map[string][]string{},
			expect: &map[string][]string{"multiple": {"1", "2"}, "single": {"3"}},
		},
		{
			name:   "ok, bind to map[string][]string with nil map",
			dest:   new(map[string][]string),
			expect: &map[string][]string{"multiple": {"1", "2"}, "single": {"3"}},
		},
		{
			name:   "ok, bind to map[string]interface",
			dest:   &map[string]any{},
			expect: &map[string]any{"multiple": "1", "single": "3"},
		},
		{
			name:   "ok, bind to map[string]interface with nil map",
			dest:   new(map[string]any),
			expect: &map[string]any{"multiple": "1", "single": "3"},
		},
		{
			name:   "ok, bind to map[string]int skips",
			dest:   &map[string]int{},
			expect: &map[string]int{},
		},
		{
			name:   "ok, bind to map[string]int skips with nil map",
			dest:   new(map[string]int),
			expect: new(map[string]int),
		},
		{
			name:   "ok, bind to map[string][]int skips",
			dest:   &map[string][]int{},
			expect: &map[string][]int{},
		},
		{
			name:   "ok, bind to map[string][]int skips with nil map",
			dest:   new(map[string][]int),
			expect: new(map[string][]int),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			assert.NoError(t, bindData(tc.dest, exampleData, "param", nil))
			assert.Equal(t, tc.expect, tc.dest)
		})
	}
}
// TestBindbindData binds the values fixture with the "form" tag into bindTestStruct, which
// has no tags. Every checked field is expected to stay at its zero value.
func TestBindbindData(t *testing.T) {
	ts := new(bindTestStruct)
	assert.NoError(t, bindData(ts, values, "form", nil))

	unbound := []struct {
		name  string
		value any
	}{
		{"I", ts.I},
		{"I8", ts.I8},
		{"I16", ts.I16},
		{"I32", ts.I32},
		{"I64", ts.I64},
		{"UI", ts.UI},
		{"UI8", ts.UI8},
		{"UI16", ts.UI16},
		{"UI32", ts.UI32},
		{"UI64", ts.UI64},
		{"B", ts.B},
		{"F32", ts.F32},
		{"F64", ts.F64},
		{"S", ts.S},
		{"cantSet", ts.cantSet},
	}
	for _, f := range unbound {
		assert.Zero(t, f.value, f.name)
	}
}
// TestBindParam checks binding of path values into user: with all params present, with a
// missing param, and combined with a JSON request body.
func TestBindParam(t *testing.T) {
	e := New()

	t.Run("all path params present", func(t *testing.T) {
		c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder())
		c.InitializeRoute(
			&RouteInfo{Path: "/users/:id/:name"},
			&PathValues{
				{Name: "id", Value: "1"},
				{Name: "name", Value: "Jon Snow"},
			},
		)
		got := new(user)
		if assert.NoError(t, c.Bind(got)) {
			assert.Equal(t, 1, got.ID)
			assert.Equal(t, "Jon Snow", got.Name)
		}
	})

	t.Run("missing path param leaves field empty", func(t *testing.T) {
		c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder())
		c.InitializeRoute(
			&RouteInfo{Path: "/users/:id"},
			&PathValues{{Name: "id", Value: "1"}},
		)
		got := new(user)
		if assert.NoError(t, c.Bind(got)) {
			assert.Equal(t, 1, got.ID)
			assert.Equal(t, "", got.Name)
		}
	})

	t.Run("path param combined with JSON body", func(t *testing.T) {
		req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString(`{ "name": "Jon Snow" }`))
		req.Header.Set(HeaderContentType, MIMEApplicationJSON)
		c := New().NewContext(req, httptest.NewRecorder())
		c.InitializeRoute(
			&RouteInfo{Path: "/users/:id"},
			&PathValues{{Name: "id", Value: "1"}},
		)
		got := new(user)
		if assert.NoError(t, c.Bind(got)) {
			assert.Equal(t, 1, got.ID)
			assert.Equal(t, "Jon Snow", got.Name)
		}
	})
}
func TestBindUnmarshalTypeError(t *testing.T) {
body := bytes.NewBufferString(`{ "id": "text" }`)
e := New()
req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set(HeaderContentType, MIMEApplicationJSON)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
u := new(user)
err := c.Bind(u)
assert.EqualError(t, err, `code=400, message=Bad Request, err=json: cannot unmarshal string into Go struct field user.id of type int`)
}
// TestBindSetWithProperType feeds every settable field of bindTestStruct that has an entry
// in the values fixture through setWithProperType, then checks the result. It also checks
// that a kind with no conversion and no unmarshaler (a bytes.Buffer struct) is rejected.
//
// The loop's input string is named `input` rather than re-using `val`. The original
// shadowed the outer reflect.Value `val` with a string of the same name.
func TestBindSetWithProperType(t *testing.T) {
	ts := new(bindTestStruct)
	typ := reflect.TypeOf(ts).Elem()
	val := reflect.ValueOf(ts).Elem()
	for i := range typ.NumField() {
		typeField := typ.Field(i)
		structField := val.Field(i)
		if !structField.CanSet() {
			continue // e.g. the unexported cantSet field
		}
		inputs := values[typeField.Name]
		if len(inputs) == 0 {
			continue
		}
		input := inputs[0]
		err := setWithProperType(typeField.Type.Kind(), input, structField)
		assert.NoError(t, err, "field %s", typeField.Name)
	}
	assertBindTestStruct(t, ts)

	type foo struct {
		Bar bytes.Buffer
	}
	v := &foo{}
	fooType := reflect.TypeOf(v).Elem()
	fooVal := reflect.ValueOf(v).Elem()
	assert.Error(t, setWithProperType(fooType.Field(0).Type.Kind(), "5", fooVal.Field(0)))
}
// BenchmarkBindbindDataWithTags measures bindData into the form-tagged struct using the
// values fixture, then verifies the last run produced the expected result.
func BenchmarkBindbindDataWithTags(b *testing.B) {
	b.ReportAllocs()
	target := new(bindTestStructWithTags)
	var lastErr error

	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		lastErr = bindData(target, values, "form", nil)
	}

	assert.NoError(b, lastErr)
	// Legal conversion: the two struct types differ only in their field tags.
	assertBindTestStruct(b, (*bindTestStruct)(target))
}
// assertBindTestStruct checks the scalar fields of ts against the values fixture. Plain I
// and UI are expected to be 0 because the fixture supplies "0" for them. The unexported
// cantSet field must remain empty.
func assertBindTestStruct(tb testing.TB, ts *bindTestStruct) {
	checks := []struct {
		want, got any
	}{
		{0, ts.I},
		{int8(8), ts.I8},
		{int16(16), ts.I16},
		{int32(32), ts.I32},
		{int64(64), ts.I64},
		{uint(0), ts.UI},
		{uint8(8), ts.UI8},
		{uint16(16), ts.UI16},
		{uint32(32), ts.UI32},
		{uint64(64), ts.UI64},
		{true, ts.B},
		{float32(32.5), ts.F32},
		{float64(64.5), ts.F64},
		{"test", ts.S},
		{"", ts.GetCantSet()},
	}
	for _, check := range checks {
		assert.Equal(tb, check.want, check.got)
	}
}
// testBindOkay POSTs r with the given content type (and optional query string) and expects
// c.Bind to fill a user with ID 1 and name "Jon Snow".
//
// assert.NoError replaces assert.Equal(t, nil, err). It is the testify idiom for error
// checks and reports the error message rather than a generic value mismatch.
func testBindOkay(t *testing.T, r io.Reader, query url.Values, ctype string) {
	e := New()
	path := "/"
	if len(query) > 0 {
		path += "?" + query.Encode()
	}
	req := httptest.NewRequest(http.MethodPost, path, r)
	req.Header.Set(HeaderContentType, ctype)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	u := new(user)
	err := c.Bind(u)
	if assert.NoError(t, err) {
		assert.Equal(t, 1, u.ID)
		assert.Equal(t, "Jon Snow", u.Name)
	}
}
// testBindArrayOkay POSTs r with the given content type (and optional query string) and
// expects c.Bind to produce exactly one user with ID 1 and name "Jon Snow".
//
// The length is checked with assert.Len before indexing. The original asserted the length
// but then read u[0] anyway, so an empty result panicked instead of failing the test.
func testBindArrayOkay(t *testing.T, r io.Reader, query url.Values, ctype string) {
	e := New()
	path := "/"
	if len(query) > 0 {
		path += "?" + query.Encode()
	}
	req := httptest.NewRequest(http.MethodPost, path, r)
	req.Header.Set(HeaderContentType, ctype)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	u := []user{}
	err := c.Bind(&u)
	if assert.NoError(t, err) && assert.Len(t, u, 1) {
		assert.Equal(t, 1, u[0].ID)
		assert.Equal(t, "Jon Snow", u[0].Name)
	}
}
func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal error) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", r)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(HeaderContentType, ctype)
u := new(user)
err := c.Bind(u)
switch {
case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML),
strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
if assert.IsType(t, new(HTTPError), err) {
assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code)
assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap())
}
default:
if assert.IsType(t, new(HTTPError), err) {
assert.Equal(t, ErrUnsupportedMediaType, err)
assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap())
}
}
}
// TestDefaultBinder_BindToStructFromMixedSources documents DefaultBinder.Bind when path
// params, query params and a JSON body all supply data. Each case records which source
// wins for a given HTTP method and target type.
func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
	// tests to check binding behaviour when multiple sources (path params, query params and request body) are in use
	// binding is done in steps and one source could overwrite previous source bound data
	// these tests are to document this behaviour and detect further possible regressions when bind implementation is changed
	type Opts struct {
		Node string `json:"node" form:"node" query:"node" param:"node"`
		Lang string // untagged; never expected to be bound
		ID   int    `json:"id" form:"id" query:"id"`
	}
	var testCases = []struct {
		givenContent     io.Reader
		whenBindTarget   any // nil means bind into a fresh &Opts{}
		expect           any
		name             string
		givenURL         string
		givenMethod      string
		expectError      string
		whenNoPathValues bool // when true, no "node" path value is set on the context
	}{
		{
			name:         "ok, POST bind to struct with: path param + query param + body",
			givenMethod:  http.MethodPost,
			givenURL:     "/api/real_node/endpoint?node=xxx",
			givenContent: strings.NewReader(`{"id": 1}`),
			expect:       &Opts{ID: 1, Node: "node_from_path"}, // query params are not used, node is filled from path
		},
		{
			name:         "ok, PUT bind to struct with: path param + query param + body",
			givenMethod:  http.MethodPut,
			givenURL:     "/api/real_node/endpoint?node=xxx",
			givenContent: strings.NewReader(`{"id": 1}`),
			expect:       &Opts{ID: 1, Node: "node_from_path"}, // query params are not used
		},
		{
			name:         "ok, GET bind to struct with: path param + query param + body",
			givenMethod:  http.MethodGet,
			givenURL:     "/api/real_node/endpoint?node=xxx",
			givenContent: strings.NewReader(`{"id": 1}`),
			expect:       &Opts{ID: 1, Node: "xxx"}, // query overwrites previous path value
		},
		{
			name:         "ok, GET bind to struct with: path param + query param + body",
			givenMethod:  http.MethodGet,
			givenURL:     "/api/real_node/endpoint?node=xxx",
			givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
			expect:       &Opts{ID: 1, Node: "zzz"}, // body is bound last and overwrites previous (path,query) values
		},
		{
			name:         "ok, DELETE bind to struct with: path param + query param + body",
			givenMethod:  http.MethodDelete,
			givenURL:     "/api/real_node/endpoint?node=xxx",
			givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
			expect:       &Opts{ID: 1, Node: "zzz"}, // for DELETE body is bound after query params
		},
		{
			name:         "ok, POST bind to struct with: path param + body",
			givenMethod:  http.MethodPost,
			givenURL:     "/api/real_node/endpoint",
			givenContent: strings.NewReader(`{"id": 1}`),
			expect:       &Opts{ID: 1, Node: "node_from_path"},
		},
		{
			name:         "ok, POST bind to struct with path + query + body = body has priority",
			givenMethod:  http.MethodPost,
			givenURL:     "/api/real_node/endpoint?node=xxx",
			givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
			expect:       &Opts{ID: 1, Node: "zzz"}, // field value from content has higher priority
		},
		{
			name:         "nok, POST body bind failure",
			givenMethod:  http.MethodPost,
			givenURL:     "/api/real_node/endpoint?node=xxx",
			givenContent: strings.NewReader(`{`),
			expect:       &Opts{ID: 0, Node: "node_from_path"}, // path param binding has already modified bind target (query is not bound for POST)
			expectError:  "code=400, message=Bad Request, err=unexpected EOF",
		},
		{
			name:         "nok, GET with body bind failure when types are not convertible",
			givenMethod:  http.MethodGet,
			givenURL:     "/api/real_node/endpoint?id=nope",
			givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
			expect:       &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target
			expectError:  `code=400, message=Bad Request, err=strconv.ParseInt: parsing "nope": invalid syntax`,
		},
		{
			name:         "nok, GET body bind failure - trying to bind json array to struct",
			givenMethod:  http.MethodGet,
			givenURL:     "/api/real_node/endpoint?node=xxx",
			givenContent: strings.NewReader(`[{"id": 1}]`),
			expect:       &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target
			expectError:  `code=400, message=Bad Request, err=json: cannot unmarshal array into Go value of type echo.Opts`,
		},
		{ // query param is ignored as we do not know where exactly to bind it in slice
			name:             "ok, GET bind to struct slice, ignore query param",
			givenMethod:      http.MethodGet,
			givenURL:         "/api/real_node/endpoint?node=xxx",
			givenContent:     strings.NewReader(`[{"id": 1}]`),
			whenNoPathValues: true,
			whenBindTarget:   &[]Opts{},
			expect: &[]Opts{
				{ID: 1, Node: ""},
			},
		},
		{ // query params are not bound for POST, so the non-numeric id=nope in the URL cannot cause an error; only the body is bound into the slice
			name:             "ok, POST binding to slice should not be affected query params types",
			givenMethod:      http.MethodPost,
			givenURL:         "/api/real_node/endpoint?id=nope&node=xxx",
			givenContent:     strings.NewReader(`[{"id": 1}]`),
			whenNoPathValues: true,
			whenBindTarget:   &[]Opts{},
			expect:           &[]Opts{{ID: 1}},
			expectError:      "",
		},
		{ // path param is ignored as we do not know where exactly to bind it in slice
			name:           "ok, GET bind to struct slice, ignore path param",
			givenMethod:    http.MethodGet,
			givenURL:       "/api/real_node/endpoint?node=xxx",
			givenContent:   strings.NewReader(`[{"id": 1}]`),
			whenBindTarget: &[]Opts{},
			expect: &[]Opts{
				{ID: 1, Node: ""},
			},
		},
		{
			name:             "ok, GET body bind json array to slice",
			givenMethod:      http.MethodGet,
			givenURL:         "/api/real_node/endpoint",
			givenContent:     strings.NewReader(`[{"id": 1}]`),
			whenNoPathValues: true,
			whenBindTarget:   &[]Opts{},
			expect:           &[]Opts{{ID: 1, Node: ""}},
			expectError:      "",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			// assume route we are testing is "/api/:node/endpoint?some_query_params=here"
			req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent)
			req.Header.Set(HeaderContentType, MIMEApplicationJSON)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)
			if !tc.whenNoPathValues {
				c.SetPathValues(PathValues{
					{Name: "node", Value: "node_from_path"},
				})
			}
			var bindTarget any
			if tc.whenBindTarget != nil {
				bindTarget = tc.whenBindTarget
			} else {
				bindTarget = &Opts{}
			}
			b := new(DefaultBinder)
			err := b.Bind(c, bindTarget)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
			// the target is compared even on error, to document partial binding
			assert.Equal(t, tc.expect, bindTarget)
		})
	}
}
func TestDefaultBinder_BindBody(t *testing.T) {
// tests to check binding behaviour when multiple sources (path params, query params and request body) are in use
// generally when binding from request body - URL and path params are ignored - unless form is being bound.
// these tests are to document this behaviour and detect further possible regressions when bind implementation is changed
type Node struct {
Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"`
ID int `json:"id" xml:"id" form:"id" query:"id"`
}
type Nodes struct {
Nodes []Node `xml:"node" form:"node"`
}
var testCases = []struct {
givenContent io.Reader
whenBindTarget any
expect any
name string
givenURL string
givenMethod string
givenContentType string
expectError string
whenNoPathValues bool
whenChunkedBody bool
}{
{
name: "ok, JSON POST bind to struct with: path + query + empty field in body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationJSON,
givenContent: strings.NewReader(`{"id": 1}`),
expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body
},
{
name: "ok, JSON POST bind to struct with: path + query + body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationJSON,
givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority
},
{
name: "ok, JSON POST body bind json array to slice (has matching path/query params)",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationJSON,
givenContent: strings.NewReader(`[{"id": 1}]`),
whenNoPathValues: true,
whenBindTarget: &[]Node{},
expect: &[]Node{{ID: 1, Node: ""}},
expectError: "",
},
{ // rare case as GET is not usually used to send request body
name: "ok, JSON GET bind to struct with: path + query + empty field in body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodGet,
givenContentType: MIMEApplicationJSON,
givenContent: strings.NewReader(`{"id": 1}`),
expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body
},
{ // rare case as GET is not usually used to send request body
name: "ok, JSON GET bind to struct with: path + query + body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodGet,
givenContentType: MIMEApplicationJSON,
givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority
},
{
name: "nok, JSON POST body bind failure",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationJSON,
givenContent: strings.NewReader(`{`),
expect: &Node{ID: 0, Node: ""},
expectError: "code=400, message=Bad Request, err=unexpected EOF",
},
{
name: "ok, XML POST bind to struct with: path + query + empty body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationXML,
givenContent: strings.NewReader(`1 yyy `),
expect: &Node{ID: 1, Node: "yyy"},
},
{
name: "ok, XML POST bind array to slice with: path + query + body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationXML,
givenContent: strings.NewReader(`1 yyy `),
whenBindTarget: &Nodes{},
expect: &Nodes{Nodes: []Node{{ID: 1, Node: "yyy"}}},
},
{
name: "nok, XML POST bind failure",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationXML,
givenContent: strings.NewReader(`<`),
expect: &Node{ID: 0, Node: ""},
expectError: "code=400, message=Bad Request, err=XML syntax error on line 1: unexpected EOF",
},
{
name: "ok, FORM POST bind to struct with: path + query + body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationForm,
givenContent: strings.NewReader(`id=1&node=yyy`),
expect: &Node{ID: 1, Node: "yyy"},
},
{
// NB: form values are taken from BOTH body and query for POST/PUT/PATCH by standard library implementation
// See: https://golang.org/pkg/net/http/#Request.ParseForm
name: "ok, FORM POST bind to struct with: path + query + empty field in body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationForm,
givenContent: strings.NewReader(`id=1`),
expect: &Node{ID: 1, Node: "xxx"},
},
{
// NB: form values are taken from query by standard library implementation
// See: https://golang.org/pkg/net/http/#Request.ParseForm
name: "ok, FORM GET bind to struct with: path + query + empty field in body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodGet,
givenContentType: MIMEApplicationForm,
givenContent: strings.NewReader(`id=1`),
expect: &Node{ID: 0, Node: "xxx"}, // 'xxx' is taken from URL and body is not used with GET by implementation
},
{
name: "nok, unsupported content type",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMETextPlain,
givenContent: strings.NewReader(``),
expect: &Node{ID: 0, Node: ""},
expectError: "code=415, message=Unsupported Media Type",
},
// FIXME: REASON in Go 1.24 and earlier http.NoBody would result ContentLength=-1
// but as of Go 1.25 http.NoBody would result ContentLength=0
// I am too lazy to bother documenting this as 2 version specific tests.
//{
// name: "nok, JSON POST with http.NoBody",
// givenURL: "/api/real_node/endpoint?node=xxx",
// givenMethod: http.MethodPost,
// givenContentType: MIMEApplicationJSON,
// givenContent: http.NoBody,
// expect: &Node{ID: 0, Node: ""},
// expectError: "code=400, message=EOF, internal=EOF",
//},
{
name: "ok, JSON POST with empty body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationJSON,
givenContent: strings.NewReader(""),
expect: &Node{ID: 0, Node: ""},
},
{
name: "ok, JSON POST bind to struct with: path + query + chunked body",
givenURL: "/api/real_node/endpoint?node=xxx",
givenMethod: http.MethodPost,
givenContentType: MIMEApplicationJSON,
givenContent: httputil.NewChunkedReader(strings.NewReader("18\r\n" + `{"id": 1, "node": "zzz"}` + "\r\n0\r\n")),
whenChunkedBody: true,
expect: &Node{ID: 1, Node: "zzz"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
// assume route we are testing is "/api/:node/endpoint?some_query_params=here"
req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent)
switch tc.givenContentType {
case MIMEApplicationXML:
req.Header.Set(HeaderContentType, MIMEApplicationXML)
case MIMEApplicationForm:
req.Header.Set(HeaderContentType, MIMEApplicationForm)
case MIMEApplicationJSON:
req.Header.Set(HeaderContentType, MIMEApplicationJSON)
}
if tc.whenChunkedBody {
req.ContentLength = -1
req.TransferEncoding = append(req.TransferEncoding, "chunked")
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if !tc.whenNoPathValues {
c.SetPathValues(PathValues{
{Name: "node", Value: "real_node"},
})
}
var bindTarget any
if tc.whenBindTarget != nil {
bindTarget = tc.whenBindTarget
} else {
bindTarget = &Node{}
}
err := BindBody(c, bindTarget)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.expect, bindTarget)
})
}
}
// testBindURL issues a GET request for the given URL (path plus query string, e.g. "/?a=1&a=2"),
// wraps it in a fresh echo Context and runs the default Bind into target, returning Bind's error.
func testBindURL(queryString string, target any) error {
	req := httptest.NewRequest(http.MethodGet, queryString, nil)
	ctx := New().NewContext(req, httptest.NewRecorder())
	return ctx.Bind(target)
}
// unixTimestamp wraps time.Time so that a parameter holding whole seconds since the Unix epoch
// can be bound through the single-value BindUnmarshaler interface.
type unixTimestamp struct {
	Time time.Time
}

// UnmarshalParam parses param as a base-10 integer number of seconds since the Unix epoch and
// stores the corresponding time. A non-integer param yields an error and leaves t unchanged.
func (t *unixTimestamp) UnmarshalParam(param string) error {
	seconds, parseErr := strconv.ParseInt(param, 10, 64)
	if parseErr != nil {
		return fmt.Errorf("'%s' is not an integer", param)
	}
	t.Time = time.Unix(seconds, 0)
	return nil
}
// IntArrayA is an int slice used to test binding through the single-value BindUnmarshaler
// interface (UnmarshalParam is handed one input value only).
type IntArrayA []int

// UnmarshalParam parses value as a comma-separated list of integers (e.g. "1,2,3") and appends
// them to *i. This allows the API to accept such a list as a single query parameter.
// If any element is not a valid int, an error naming that element is returned and *i is left
// unchanged.
func (i *IntArrayA) UnmarshalParam(value string) error {
	parts := strings.Split(value, ",")
	numbers := make([]int, 0, len(parts))
	for _, part := range parts {
		// strconv.Atoi parses with the platform's int size, so values that do not fit into int are
		// rejected instead of being silently truncated by an int64 -> int conversion on 32-bit targets.
		n, err := strconv.Atoi(part)
		if err != nil {
			return fmt.Errorf("'%s' is not an integer", part)
		}
		numbers = append(numbers, n)
	}
	*i = append(*i, numbers...)
	return nil
}
// TestBindUnmarshalParamExtras documents how Bind handles types implementing the
// `BindUnmarshaler` interface.
// NOTE: a BindUnmarshaler only ever receives the FIRST input value of a parameter.
func TestBindUnmarshalParamExtras(t *testing.T) {
	t.Run("nok, unmarshalling fails", func(t *testing.T) {
		var got struct {
			V unixTimestamp `query:"t"`
		}
		err := testBindURL("/?t=xxxx", &got)
		assert.EqualError(t, err, `code=400, message=Bad Request, err='xxxx' is not an integer`)
	})

	t.Run("ok, target is struct", func(t *testing.T) {
		var got struct {
			V unixTimestamp `query:"t"`
		}
		err := testBindURL("/?t=1710095540&t=1710095541", &got)
		assert.NoError(t, err)
		// only the first of the two `t` values is used
		assert.Equal(t, unixTimestamp{Time: time.Unix(1710095540, 0)}, got.V)
	})

	t.Run("ok, target is an alias to slice and is nil, append only values from first", func(t *testing.T) {
		var got struct {
			V IntArrayA `query:"a"`
		}
		err := testBindURL("/?a=1,2,3&a=4,5,6", &got)
		assert.NoError(t, err)
		assert.Equal(t, IntArrayA{1, 2, 3}, got.V)
	})

	t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) {
		var got struct {
			V IntArrayA `query:"a"`
		}
		err := testBindURL("/?a=1,2", &got)
		assert.NoError(t, err)
		assert.Equal(t, IntArrayA{1, 2}, got.V)
	})

	t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) {
		var got struct {
			V *IntArrayA `query:"a"`
		}
		err := testBindURL("/?a=1&a=4,5,6", &got)
		assert.NoError(t, err)
		want := IntArrayA{1}
		assert.Equal(t, &want, got.V)
	})

	t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) {
		var got struct {
			V *IntArrayA `query:"a"`
		}
		got.V = new(IntArrayA) // pre-allocated, i.e. NOT nil
		err := testBindURL("/?a=1&a=4,5,6", &got)
		assert.NoError(t, err)
		want := IntArrayA{1}
		assert.Equal(t, &want, got.V)
	})
}
// unixTimestampLast wraps time.Time; it is bound through the multi-value
// `bindMultipleUnmarshaler` interface (UnmarshalParams receives all input values).
type unixTimestampLast struct {
	Time time.Time
}

// UnmarshalParams is a deliberately silly `bindMultipleUnmarshaler` example: it uses only the
// LAST input value, parsed as a base-10 number of seconds since the Unix epoch.
// It returns an error (leaving t unchanged) when params is empty or the last value is not an integer.
func (t *unixTimestampLast) UnmarshalParams(params []string) error {
	if len(params) == 0 {
		// guard: indexing params[len(params)-1] would panic on an empty slice
		return fmt.Errorf("no values to unmarshal")
	}
	lastInput := params[len(params)-1]
	n, err := strconv.ParseInt(lastInput, 10, 64)
	if err != nil {
		return fmt.Errorf("'%s' is not an integer", lastInput)
	}
	*t = unixTimestampLast{Time: time.Unix(n, 0)}
	return nil
}
// IntArrayB is an int slice used to test binding through the multi-value
// `bindMultipleUnmarshaler` interface (UnmarshalParams receives every input value).
type IntArrayB []int

// UnmarshalParams splits each param on "," and appends all parsed integers, in input order, to *i.
// For example params ["1,2", "3"] append 1, 2, 3. If any element is not a valid int, an error
// naming that element is returned and *i is left unchanged.
func (i *IntArrayB) UnmarshalParams(params []string) error {
	numbers := make([]int, 0, len(params))
	for _, param := range params {
		for _, v := range strings.Split(param, ",") {
			// strconv.Atoi parses with the platform's int size, so values that do not fit into int are
			// rejected instead of being silently truncated by an int64 -> int conversion on 32-bit targets.
			n, err := strconv.Atoi(v)
			if err != nil {
				return fmt.Errorf("'%s' is not an integer", v)
			}
			numbers = append(numbers, n)
		}
	}
	*i = append(*i, numbers...)
	return nil
}
// TestBindUnmarshalParams documents how Bind handles types implementing the
// `bindMultipleUnmarshaler` interface, which receives ALL input values of a parameter.
func TestBindUnmarshalParams(t *testing.T) {
	t.Run("nok, unmarshalling fails", func(t *testing.T) {
		var got struct {
			V unixTimestampLast `query:"t"`
		}
		err := testBindURL("/?t=xxxx", &got)
		assert.EqualError(t, err, "code=400, message=Bad Request, err='xxxx' is not an integer")
	})

	t.Run("ok, target is struct", func(t *testing.T) {
		var got struct {
			V unixTimestampLast `query:"t"`
		}
		err := testBindURL("/?t=1710095540&t=1710095541", &got)
		assert.NoError(t, err)
		// unixTimestampLast uses the last of the two `t` values
		assert.Equal(t, unixTimestampLast{Time: time.Unix(1710095541, 0)}, got.V)
	})

	t.Run("ok, target is an alias to slice and is nil, append multiple inputs", func(t *testing.T) {
		var got struct {
			V IntArrayB `query:"a"`
		}
		err := testBindURL("/?a=1,2,3&a=4,5,6", &got)
		assert.NoError(t, err)
		assert.Equal(t, IntArrayB{1, 2, 3, 4, 5, 6}, got.V)
	})

	t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) {
		var got struct {
			V IntArrayB `query:"a"`
		}
		err := testBindURL("/?a=1,2", &got)
		assert.NoError(t, err)
		assert.Equal(t, IntArrayB{1, 2}, got.V)
	})

	t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) {
		var got struct {
			V *IntArrayB `query:"a"`
		}
		err := testBindURL("/?a=1&a=4,5,6", &got)
		assert.NoError(t, err)
		want := IntArrayB{1, 4, 5, 6}
		assert.Equal(t, &want, got.V)
	})

	t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) {
		var got struct {
			V *IntArrayB `query:"a"`
		}
		got.V = new(IntArrayB) // pre-allocated, i.e. NOT nil
		err := testBindURL("/?a=1&a=4,5,6", &got)
		assert.NoError(t, err)
		want := IntArrayB{1, 4, 5, 6}
		assert.Equal(t, &want, got.V)
	})
}
// TestBindInt8 documents how Bind fills int8 targets: plain fields, pointers, slices, slices of
// pointers and pointers to slices, as well as unexported embedded fields (which cannot be set).
func TestBindInt8(t *testing.T) {
	t.Run("nok, binding fails", func(t *testing.T) {
		type target struct {
			V int8 `query:"v"`
		}
		var dst target
		err := testBindURL("/?v=x&v=2", &dst)
		assert.EqualError(t, err, `code=400, message=Bad Request, err=strconv.ParseInt: parsing "x": invalid syntax`)
	})

	t.Run("nok, int8 embedded in struct", func(t *testing.T) {
		type target struct {
			int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields
		}
		var dst target
		err := testBindURL("/?v=1&v=2", &dst)
		assert.NoError(t, err)
		assert.Equal(t, target{}, dst) // left untouched
	})

	t.Run("nok, pointer to int8 embedded in struct", func(t *testing.T) {
		type target struct {
			*int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields
		}
		var dst target
		err := testBindURL("/?v=1&v=2", &dst)
		assert.NoError(t, err)
		assert.Equal(t, target{int8: nil}, dst)
	})

	t.Run("ok, bind int8 as struct field", func(t *testing.T) {
		type target struct {
			V int8 `query:"v"`
		}
		dst := target{V: 127}
		err := testBindURL("/?v=1&v=2", &dst)
		assert.NoError(t, err)
		assert.Equal(t, target{V: 1}, dst) // first value wins, previous value overwritten
	})

	t.Run("ok, bind pointer to int8 as struct field, value is nil", func(t *testing.T) {
		type target struct {
			V *int8 `query:"v"`
		}
		var dst target
		err := testBindURL("/?v=1&v=2", &dst)
		assert.NoError(t, err)
		assert.Equal(t, target{V: ptr(int8(1))}, dst)
	})

	t.Run("ok, bind pointer to int8 as struct field, value is set", func(t *testing.T) {
		type target struct {
			V *int8 `query:"v"`
		}
		dst := target{V: ptr(int8(127))}
		err := testBindURL("/?v=1&v=2", &dst)
		assert.NoError(t, err)
		assert.Equal(t, target{V: ptr(int8(1))}, dst)
	})

	t.Run("ok, bind int8 slice as struct field, value is nil", func(t *testing.T) {
		type target struct {
			V []int8 `query:"v"`
		}
		var dst target
		err := testBindURL("/?v=1&v=2", &dst)
		assert.NoError(t, err)
		assert.Equal(t, target{V: []int8{1, 2}}, dst)
	})

	t.Run("ok, bind slice of int8 as struct field, value is set", func(t *testing.T) {
		type target struct {
			V []int8 `query:"v"`
		}
		dst := target{V: []int8{111}}
		err := testBindURL("/?v=1&v=2", &dst)
		assert.NoError(t, err)
		assert.Equal(t, target{V: []int8{1, 2}}, dst) // previous content replaced, not appended
	})

	t.Run("ok, bind slice of pointer to int8 as struct field, value is set", func(t *testing.T) {
		type target struct {
			V []*int8 `query:"v"`
		}
		dst := target{V: []*int8{ptr(int8(127))}}
		err := testBindURL("/?v=1&v=2", &dst)
		assert.NoError(t, err)
		assert.Equal(t, target{V: []*int8{ptr(int8(1)), ptr(int8(2))}}, dst)
	})

	t.Run("ok, bind pointer to slice of int8 as struct field, value is set", func(t *testing.T) {
		type target struct {
			V *[]int8 `query:"v"`
		}
		dst := target{V: &[]int8{111}}
		err := testBindURL("/?v=1&v=2", &dst)
		assert.NoError(t, err)
		assert.Equal(t, target{V: &[]int8{1, 2}}, dst)
	})
}
// TestBindMultipartFormFiles documents which target field types Bind accepts for uploaded
// multipart files: a pointer to multipart.FileHeader (first file wins), or slices of values or
// pointers (all files of that field). A non-pointer multipart.FileHeader field is rejected.
func TestBindMultipartFormFiles(t *testing.T) {
	var (
		file1  = createTestFormFile("file", "file1.txt")
		file11 = createTestFormFile("file", "file11.txt")
		file2  = createTestFormFile("file2", "file2.txt")
		filesA = createTestFormFile("files", "filesA.txt")
		filesB = createTestFormFile("files", "filesB.txt")
	)

	t.Run("nok, can not bind to multipart file struct", func(t *testing.T) {
		var dst struct {
			File multipart.FileHeader `form:"file"`
		}
		err := bindMultipartFiles(t, &dst, file1, file2) // file2 should be ignored
		assert.EqualError(t, err, `code=400, message=Bad Request, err=binding to multipart.FileHeader struct is not supported, use pointer to struct`)
	})

	t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) {
		var dst struct {
			File *multipart.FileHeader `form:"file"`
		}
		err := bindMultipartFiles(t, &dst, file1, file2) // file2 should be ignored
		assert.NoError(t, err)
		assertMultipartFileHeader(t, dst.File, file1)
	})

	t.Run("ok, bind multiple multipart files to pointer to multipart file", func(t *testing.T) {
		var dst struct {
			File *multipart.FileHeader `form:"file"`
		}
		err := bindMultipartFiles(t, &dst, file1, file11)
		assert.NoError(t, err)
		assertMultipartFileHeader(t, dst.File, file1) // should choose first one
	})

	t.Run("ok, bind multiple multipart files to slice of multipart file", func(t *testing.T) {
		var dst struct {
			Files []multipart.FileHeader `form:"files"`
		}
		err := bindMultipartFiles(t, &dst, filesA, filesB, file1)
		assert.NoError(t, err)
		assert.Len(t, dst.Files, 2)
		assertMultipartFileHeader(t, &dst.Files[0], filesA)
		assertMultipartFileHeader(t, &dst.Files[1], filesB)
	})

	t.Run("ok, bind multiple multipart files to slice of pointer to multipart file", func(t *testing.T) {
		var dst struct {
			Files []*multipart.FileHeader `form:"files"`
		}
		err := bindMultipartFiles(t, &dst, filesA, filesB, file1)
		assert.NoError(t, err)
		assert.Len(t, dst.Files, 2)
		assertMultipartFileHeader(t, dst.Files[0], filesA)
		assertMultipartFileHeader(t, dst.Files[1], filesB)
	})
}
// testFormFile describes one file part to be written into a multipart/form-data request body.
type testFormFile struct {
	Fieldname string
	Filename  string
	Content   []byte
}

// createTestFormFile returns a testFormFile for form field formFieldName whose content is
// filename repeated 10 times, so every file has distinct and verifiable content.
func createTestFormFile(formFieldName string, filename string) testFormFile {
	var f testFormFile
	f.Fieldname = formFieldName
	f.Filename = filename
	f.Content = []byte(strings.Repeat(filename, 10))
	return f
}
// bindMultipartFiles builds a multipart/form-data POST request containing the given files, wraps
// it in an echo Context and runs the default Bind into target, returning Bind's error.
// Failures while building the request abort the (sub)test with t.Fatalf. A broken request is
// therefore never bound, and the helper never continues with a nil writer or request.
func bindMultipartFiles(t *testing.T, target any, files ...testFormFile) error {
	t.Helper()

	var body bytes.Buffer
	mw := multipart.NewWriter(&body)
	for _, file := range files {
		fw, err := mw.CreateFormFile(file.Fieldname, file.Filename)
		if err != nil {
			t.Fatalf("create form file %q: %v", file.Filename, err)
		}
		n, err := fw.Write(file.Content)
		if err != nil {
			t.Fatalf("write form file %q: %v", file.Filename, err)
		}
		assert.Equal(t, len(file.Content), n)
	}
	// Close writes the terminating boundary; without it the body is not a valid multipart message.
	if err := mw.Close(); err != nil {
		t.Fatalf("close multipart writer: %v", err)
	}

	req, err := http.NewRequest(http.MethodPost, "/", &body)
	if err != nil {
		t.Fatalf("create request: %v", err)
	}
	req.Header.Set(HeaderContentType, mw.FormDataContentType())
	rec := httptest.NewRecorder()
	e := New()
	c := e.NewContext(req, rec)
	return c.Bind(target)
}
// assertMultipartFileHeader asserts that fh describes file: same filename, same size, and
// contents identical to file.Content.
// A nil fh or a file that cannot be opened stops the (sub)test with t.Fatalf instead of
// panicking on a nil dereference.
func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file testFormFile) {
	t.Helper()
	if fh == nil {
		t.Fatalf("expected file header for %q, got nil", file.Filename)
	}
	assert.Equal(t, file.Filename, fh.Filename)
	assert.Equal(t, int64(len(file.Content)), fh.Size)

	fl, err := fh.Open()
	if err != nil {
		t.Fatalf("open %q: %v", file.Filename, err)
	}
	// close on every path, and still report a close failure
	defer func() {
		assert.NoError(t, fl.Close())
	}()

	body, err := io.ReadAll(fl)
	assert.NoError(t, err)
	assert.Equal(t, string(file.Content), string(body))
}
// TestTimeFormatBinding checks binding of time.Time fields.
// The `format` struct tag selects the parsing layout, fields without it use the default parsing,
// and pointer fields are allocated when a value is present. Every field of the result is
// verified, including those expected to remain zero/nil.
func TestTimeFormatBinding(t *testing.T) {
	type TestStruct struct {
		DateTimeLocal time.Time  `form:"datetime_local" format:"2006-01-02T15:04"`
		Date          time.Time  `query:"date" format:"2006-01-02"`
		CustomFormat  time.Time  `form:"custom" format:"01/02/2006 15:04:05"`
		DefaultTime   time.Time  `form:"default_time"` // No format tag - should use default parsing
		PtrTime       *time.Time `query:"ptr_time" format:"2006-01-02"`
	}

	// Expected PtrTime for the query-params case. It lives in the test table (not in the
	// assertion loop), so each case fully describes its own expectation.
	ptrTime := time.Date(2023, 2, 20, 0, 0, 0, 0, time.UTC)

	testCases := []struct {
		name        string
		contentType string
		data        string
		queryParams string
		expect      TestStruct
		expectError bool
	}{
		{
			name:        "ok, datetime-local format binding",
			contentType: MIMEApplicationForm,
			data:        "datetime_local=2023-12-25T14:30&default_time=2023-12-25T14:30:45Z",
			expect: TestStruct{
				DateTimeLocal: time.Date(2023, 12, 25, 14, 30, 0, 0, time.UTC),
				DefaultTime:   time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC),
			},
		},
		{
			name:        "ok, date format binding via query params",
			queryParams: "?date=2023-01-15&ptr_time=2023-02-20",
			expect: TestStruct{
				Date:    time.Date(2023, 1, 15, 0, 0, 0, 0, time.UTC),
				PtrTime: &ptrTime,
			},
		},
		{
			name:        "ok, custom format via form data",
			contentType: MIMEApplicationForm,
			data:        "custom=12/25/2023 14:30:45",
			expect: TestStruct{
				CustomFormat: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC),
			},
		},
		{
			name:        "nok, invalid format should fail",
			contentType: MIMEApplicationForm,
			data:        "datetime_local=invalid-date",
			expectError: true,
		},
		{
			name:        "nok, wrong format should fail",
			contentType: MIMEApplicationForm,
			data:        "datetime_local=2023-12-25", // Missing time part
			expectError: true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			target := "/" + tc.queryParams

			// Requests with a body content type are sent as POST, anything else as a body-less GET.
			var req *http.Request
			switch tc.contentType {
			case MIMEApplicationJSON, MIMEApplicationForm:
				req = httptest.NewRequest(http.MethodPost, target, strings.NewReader(tc.data))
				req.Header.Set(HeaderContentType, tc.contentType)
			default:
				req = httptest.NewRequest(http.MethodGet, target, nil)
			}
			rec := httptest.NewRecorder()
			c := New().NewContext(req, rec)

			var result TestStruct
			err := c.Bind(&result)
			if tc.expectError {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)

			// Compare with time.Time.Equal (same instant) rather than ==, so differences in
			// location or monotonic clock data do not cause false failures. Zero values are
			// compared too, so fields that must stay unset are verified as well.
			assertSameInstant := func(field string, want, got time.Time) {
				t.Helper()
				assert.True(t, want.Equal(got), "%s: expected %v, got %v", field, want, got)
			}
			assertSameInstant("DateTimeLocal", tc.expect.DateTimeLocal, result.DateTimeLocal)
			assertSameInstant("Date", tc.expect.Date, result.Date)
			assertSameInstant("CustomFormat", tc.expect.CustomFormat, result.CustomFormat)
			assertSameInstant("DefaultTime", tc.expect.DefaultTime, result.DefaultTime)
			if tc.expect.PtrTime == nil {
				assert.Nil(t, result.PtrTime, "PtrTime: expected nil")
			} else if assert.NotNil(t, result.PtrTime) {
				assertSameInstant("PtrTime", *tc.expect.PtrTime, *result.PtrTime)
			}
		})
	}
}
================================================
FILE: binder.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"encoding"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
)
/**
The following functions provide a handful of methods for binding to Go native types from request query or path parameters.
* QueryParamsBinder(c) - binds query parameters (source URL)
* PathValuesBinder(c) - binds path parameters (source URL)
* FormFieldBinder(c) - binds form fields (source URL + body)
Example:
```go
var length int64
err := echo.QueryParamsBinder(c).Int64("length", &length).BindError()
```
For every supported type there are the following methods:
* {Type}("param", &destination) - if the parameter value exists, binds it to the given destination of that type, i.e. Int64(...).
* Must{Type}("param", &destination) - the parameter value is required to exist; binds it to the given destination of that type, i.e. MustInt64(...).
* {Type}s("param", &destination) - (for slices) if parameter values exist, binds them to the given destination of that type, i.e. Int64s(...).
* Must{Type}s("param", &destination) - (for slices) parameter values are required to exist; binds them to the given destination of that type, i.e. MustInt64s(...).
For some slice types `BindWithDelimiter("param", &dest, ",")` supports splitting parameter values before type conversion is done,
i.e. URL `/api/search?id=1,2,3&id=1` can be bound to `[]int64{1,2,3,1}`.
`FailFast` flags the binder to stop binding after the first bind error in the binder call chain. Enabled by default.
`BindError()` returns the first bind error from the binder and resets the errors in the binder. Useful along with the `FailFast()` method
to do binding and return on the first problem.
`BindErrors()` returns all bind errors from the binder and resets the errors in the binder.
Types that are supported:
* bool
* float32
* float64
* int
* int8
* int16
* int32
* int64
* uint
* uint8/byte (does not support `bytes()`. Use BindUnmarshaler/CustomFunc to convert value from base64 etc to []byte{})
* uint16
* uint32
* uint64
* string
* time
* duration
* BindUnmarshaler() interface
* TextUnmarshaler() interface
* JSONUnmarshaler() interface
* UnixTime() - converts unix time (integer) to time.Time
* UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time
* UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time
* CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error`
*/
// BindingError represents an error that occurred while binding request data.
// It embeds *HTTPError (NewBindingError creates it with http.StatusBadRequest) and adds the name
// of the parameter/field that failed together with the raw input values.
type BindingError struct {
// Field is the field name where value binding failed
Field string `json:"field"`
// HTTPError carries the HTTP status code, message and the wrapped internal error.
*HTTPError
// Values of parameter that failed to bind.
// Excluded from JSON output via the `json:"-"` tag.
Values []string `json:"-"`
}
// NewBindingError creates a *BindingError with HTTP status 400 (Bad Request).
// sourceParam is the failing parameter/field name, values its raw input, message the
// human-readable reason, and err the underlying (internal) error, which may be nil.
// It matches the ValueBinder.ErrorFunc signature and is used as the default ErrorFunc.
func NewBindingError(sourceParam string, values []string, message string, err error) error {
	httpErr := &HTTPError{Code: http.StatusBadRequest, Message: message, err: err}
	bindErr := &BindingError{HTTPError: httpErr, Field: sourceParam, Values: values}
	return bindErr
}
// Error returns the embedded HTTPError's message followed by ", field=<Field>".
func (be *BindingError) Error() string {
	base := be.HTTPError.Error()
	return fmt.Sprintf("%s, field=%s", base, be.Field)
}
// ValueBinder provides utility methods for binding query or path parameter to various Go built-in types.
// Binding methods return the binder itself so that calls can be chained; errors are collected
// internally and retrieved (and reset) with BindError or BindErrors.
type ValueBinder struct {
// ValueFunc is used to get single parameter (first) value from request
ValueFunc func(sourceParam string) string
// ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2`
ValuesFunc func(sourceParam string) []string
// ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response
ErrorFunc func(sourceParam string, values []string, message string, internalError error) error
// errors accumulated by the binding methods; nil means no error so far. Reset by BindError/BindErrors.
errors []error
// failFast is flag for binding methods to return without attempting to bind when previous binding already failed
failFast bool
}
// QueryParamsBinder creates a value binder that reads from the URL query parameters.
// FailFast is enabled and errors are created with NewBindingError.
func QueryParamsBinder(c *Context) *ValueBinder {
	allValues := func(sourceParam string) []string {
		if values, found := c.QueryParams()[sourceParam]; found {
			return values
		}
		return nil
	}
	return &ValueBinder{
		ValueFunc:  c.QueryParam,
		ValuesFunc: allValues,
		ErrorFunc:  NewBindingError,
		failFast:   true,
	}
}
// PathValuesBinder creates a value binder that reads from the route path parameters.
// FailFast is enabled and errors are created with NewBindingError.
func PathValuesBinder(c *Context) *ValueBinder {
	// A path parameter has a single value, so "all values" is that value wrapped in a slice
	// (or nil when it is empty). Multiple values make little sense here, but we do not error out.
	allValues := func(sourceParam string) []string {
		if value := c.Param(sourceParam); value != "" {
			return []string{value}
		}
		return nil
	}
	return &ValueBinder{
		ValueFunc:  c.Param,
		ValuesFunc: allValues,
		ErrorFunc:  NewBindingError,
		failFast:   true,
	}
}
// FormFieldBinder creates form field value binder
// For all requests, FormFieldBinder parses the raw query from the URL and uses query params as form fields
//
// For POST, PUT, and PATCH requests, it also reads the request body, parses it
// as a form and uses query params as form fields. Request body parameters take precedence over URL query
// string values in r.Form.
//
// NB: when binding forms take note that this implementation uses standard library form parsing
// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
// See https://golang.org/pkg/net/http/#Request.ParseForm
func FormFieldBinder(c *Context) *ValueBinder {
	firstValue := func(sourceParam string) string {
		return c.Request().FormValue(sourceParam)
	}
	allValues := func(sourceParam string) []string {
		if c.Request().Form == nil {
			// Parse lazily, the same way Request().FormValue() does internally: this triggers
			// c.request.ParseMultipartForm(c.formParseMaxMemory). The parse error is ignored here.
			_, _ = c.MultipartForm()
		}
		if values, found := c.Request().Form[sourceParam]; found {
			return values
		}
		return nil
	}
	return &ValueBinder{
		ValueFunc:  firstValue,
		ValuesFunc: allValues,
		ErrorFunc:  NewBindingError,
		failFast:   true,
	}
}
// FailFast sets whether binding methods return early (without binding) once a previous bind
// in the chain has failed. It returns the binder for chaining.
// NB: call this method before any other binding methods as it modifies their behaviour.
func (b *ValueBinder) FailFast(value bool) *ValueBinder {
	b.failFast = value
	return b
}
// setError records err in the binder's error list (append allocates the list when it is nil).
func (b *ValueBinder) setError(err error) {
	b.errors = append(b.errors, err)
}
// BindError returns the first recorded bind error (nil if none) and clears all recorded
// errors, so the next binding chain starts from zero.
func (b *ValueBinder) BindError() error {
	collected := b.errors
	b.errors = nil
	if collected == nil {
		return nil
	}
	return collected[0]
}
// BindErrors returns all recorded bind errors (nil if none) and clears them, so the next
// binding chain starts from zero.
func (b *ValueBinder) BindErrors() []error {
	errs := b.errors
	b.errors = nil
	return errs
}
// CustomFunc binds parameter values with customFunc. customFunc is called only when parameter
// values exist; every error it returns is recorded in the binder.
func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder {
	const mustExist = false
	return b.customFunc(sourceParam, customFunc, mustExist)
}

// MustCustomFunc is like CustomFunc but requires parameter values to exist; a
// "required field value is empty" error is recorded when they do not.
func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder {
	const mustExist = true
	return b.customFunc(sourceParam, customFunc, mustExist)
}
// customFunc implements CustomFunc/MustCustomFunc. When no values are present it either does
// nothing or (valueMustExist) records a "required field value is empty" error. Otherwise it
// appends all errors returned by customFunc to the binder.
func (b *ValueBinder) customFunc(sourceParam string, customFunc func(values []string) []error, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}

	values := b.ValuesFunc(sourceParam)
	switch {
	case len(values) > 0:
		// appending a nil/empty result leaves b.errors untouched
		b.errors = append(b.errors, customFunc(values)...)
	case valueMustExist:
		b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
	}
	return b
}
// String binds parameter to string variable. An empty/missing value leaves dest untouched.
func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	if value := b.ValueFunc(sourceParam); value != "" {
		*dest = value
	}
	return b
}
// MustString requires parameter value to exist to bind to string variable.
// Records a "required field value is empty" error (dest untouched) when the value is empty/missing.
func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	value := b.ValueFunc(sourceParam)
	if value != "" {
		*dest = value
		return b
	}
	b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
	return b
}
// Strings binds parameter values to slice of string. When ValuesFunc returns nil, dest is left untouched.
func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	if values := b.ValuesFunc(sourceParam); values != nil {
		*dest = values
	}
	return b
}
// MustStrings requires parameter values to exist to bind to slice of string variables.
// When ValuesFunc returns no values (nil or an empty slice), dest is left untouched and a
// "required field value is empty" error is recorded.
func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	values := b.ValuesFunc(sourceParam)
	// len()==0 rather than ==nil: ValuesFunc is a public, replaceable field, and a custom
	// implementation may return an empty non-nil slice. That must not satisfy the "must exist"
	// requirement; this matches the len() check used by customFunc and bindWithDelimiter.
	if len(values) == 0 {
		b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
		return b
	}
	*dest = values
	return b
}
// BindUnmarshaler binds parameter to destination implementing BindUnmarshaler interface.
// The (first) parameter value is passed to dest.UnmarshalParam; an empty/missing value is skipped.
func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	raw := b.ValueFunc(sourceParam)
	if raw != "" {
		if err := dest.UnmarshalParam(raw); err != nil {
			b.setError(b.ErrorFunc(sourceParam, []string{raw}, "failed to bind field value to BindUnmarshaler interface", err))
		}
	}
	return b
}
// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface.
// Records a "required field value is empty" error when the value is empty/missing, or a bind
// error when dest.UnmarshalParam fails.
func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	raw := b.ValueFunc(sourceParam)
	switch {
	case raw == "":
		b.setError(b.ErrorFunc(sourceParam, []string{raw}, "required field value is empty", nil))
	default:
		if err := dest.UnmarshalParam(raw); err != nil {
			b.setError(b.ErrorFunc(sourceParam, []string{raw}, "failed to bind field value to BindUnmarshaler interface", err))
		}
	}
	return b
}
// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface.
// The raw (first) parameter value is passed to dest.UnmarshalJSON as bytes; an empty/missing value is skipped.
func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	raw := b.ValueFunc(sourceParam)
	if raw != "" {
		if err := dest.UnmarshalJSON([]byte(raw)); err != nil {
			b.setError(b.ErrorFunc(sourceParam, []string{raw}, "failed to bind field value to json.Unmarshaler interface", err))
		}
	}
	return b
}
// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface.
// Records a "required field value is empty" error when the value is empty/missing, or a bind
// error when dest.UnmarshalJSON fails.
func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	raw := b.ValueFunc(sourceParam)
	switch {
	case raw == "":
		b.setError(b.ErrorFunc(sourceParam, []string{raw}, "required field value is empty", nil))
	default:
		if err := dest.UnmarshalJSON([]byte(raw)); err != nil {
			b.setError(b.ErrorFunc(sourceParam, []string{raw}, "failed to bind field value to json.Unmarshaler interface", err))
		}
	}
	return b
}
// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface.
// The raw (first) parameter value is passed to dest.UnmarshalText as bytes; an empty/missing value is skipped.
func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	raw := b.ValueFunc(sourceParam)
	if raw != "" {
		if err := dest.UnmarshalText([]byte(raw)); err != nil {
			b.setError(b.ErrorFunc(sourceParam, []string{raw}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
		}
	}
	return b
}
// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface.
// Records a "required field value is empty" error when the value is empty/missing, or a bind
// error when dest.UnmarshalText fails.
func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	raw := b.ValueFunc(sourceParam)
	switch {
	case raw == "":
		b.setError(b.ErrorFunc(sourceParam, []string{raw}, "required field value is empty", nil))
	default:
		if err := dest.UnmarshalText([]byte(raw)); err != nil {
			b.setError(b.ErrorFunc(sourceParam, []string{raw}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
		}
	}
	return b
}
// BindWithDelimiter binds parameter to destination by suitable conversion function.
// Each parameter value is first split on delimiter, i.e. "1,2" with "," yields "1" and "2".
func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder {
	const mustExist = false
	return b.bindWithDelimiter(sourceParam, dest, delimiter, mustExist)
}

// MustBindWithDelimiter is like BindWithDelimiter but requires parameter values to exist;
// a "required field value is empty" error is recorded when they do not.
func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder {
	const mustExist = true
	return b.bindWithDelimiter(sourceParam, dest, delimiter, mustExist)
}
// bindWithDelimiter collects every value of sourceParam, splits each one on delimiter and converts
// the flattened list into dest.
//
// Supported destinations are *[]string, *[]bool, the signed and unsigned integer slices,
// *[]float64, *[]float32 and *[]time.Duration. Any other destination, including *[]time.Time
// (which needs a layout) or non-slice types, is reported as "unsupported bind type".
func (b *ValueBinder) bindWithDelimiter(sourceParam string, dest any, delimiter string, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	rawValues := b.ValuesFunc(sourceParam)
	if len(rawValues) == 0 {
		if valueMustExist {
			b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
		}
		return b
	}

	parts := make([]string, 0, len(rawValues))
	for _, raw := range rawValues {
		parts = append(parts, strings.Split(raw, delimiter)...)
	}

	// strings need no conversion, so they are assigned directly
	if target, ok := dest.(*[]string); ok {
		*target = parts
		return b
	}

	// every conversion helper returns b itself, so their results can be ignored here
	switch target := dest.(type) {
	case *[]bool:
		b.bools(sourceParam, parts, target)
	case *[]int64, *[]int32, *[]int16, *[]int8, *[]int:
		b.ints(sourceParam, parts, target)
	case *[]uint64, *[]uint32, *[]uint16, *[]uint8, *[]uint: // *[]byte is the same type as *[]uint8
		b.uints(sourceParam, parts, target)
	case *[]float64, *[]float32:
		b.floats(sourceParam, parts, target)
	case *[]time.Duration:
		b.durations(sourceParam, parts, target)
	default:
		b.setError(b.ErrorFunc(sourceParam, []string{}, "unsupported bind type", nil))
	}
	return b
}
// Int64 binds the value of sourceParam to an int64 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder {
	return b.intValue(sourceParam, dest, 64 /* bitSize */, false /* valueMustExist */)
}

// MustInt64 binds the value of sourceParam to an int64 variable and records a binding error when
// the value is missing or empty.
func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder {
	return b.intValue(sourceParam, dest, 64 /* bitSize */, true /* valueMustExist */)
}

// Int32 binds the value of sourceParam to an int32 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder {
	return b.intValue(sourceParam, dest, 32 /* bitSize */, false /* valueMustExist */)
}

// MustInt32 binds the value of sourceParam to an int32 variable and records a binding error when
// the value is missing or empty.
func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder {
	return b.intValue(sourceParam, dest, 32 /* bitSize */, true /* valueMustExist */)
}

// Int16 binds the value of sourceParam to an int16 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder {
	return b.intValue(sourceParam, dest, 16 /* bitSize */, false /* valueMustExist */)
}

// MustInt16 binds the value of sourceParam to an int16 variable and records a binding error when
// the value is missing or empty.
func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder {
	return b.intValue(sourceParam, dest, 16 /* bitSize */, true /* valueMustExist */)
}

// Int8 binds the value of sourceParam to an int8 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder {
	return b.intValue(sourceParam, dest, 8 /* bitSize */, false /* valueMustExist */)
}

// MustInt8 binds the value of sourceParam to an int8 variable and records a binding error when
// the value is missing or empty.
func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder {
	return b.intValue(sourceParam, dest, 8 /* bitSize */, true /* valueMustExist */)
}

// Int binds the value of sourceParam to an int variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder {
	return b.intValue(sourceParam, dest, 0 /* bitSize: platform int */, false /* valueMustExist */)
}

// MustInt binds the value of sourceParam to an int variable and records a binding error when the
// value is missing or empty.
func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder {
	return b.intValue(sourceParam, dest, 0 /* bitSize: platform int */, true /* valueMustExist */)
}
// intValue reads the single value of sourceParam and hands it to int for conversion into dest.
// An empty value is reported only when valueMustExist is set; otherwise it is silently skipped.
func (b *ValueBinder) intValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	if value := b.ValueFunc(sourceParam); value != "" {
		return b.int(sourceParam, value, dest, bitSize)
	}
	if valueMustExist {
		b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
	}
	return b
}
// int parses value as a base-10 signed integer that must fit into bitSize bits (0 selects the
// platform int size). On success it stores the result into dest, which is expected to be one of
// *int64, *int32, *int16, *int8 or *int; other destination types are left untouched.
func (b *ValueBinder) int(sourceParam string, value string, dest any, bitSize int) *ValueBinder {
	parsed, err := strconv.ParseInt(value, 10, bitSize)
	if err != nil {
		msg := "failed to bind field value to int"
		if bitSize != 0 {
			msg = fmt.Sprintf("failed to bind field value to int%v", bitSize)
		}
		b.setError(b.ErrorFunc(sourceParam, []string{value}, msg, err))
		return b
	}

	// ParseInt already rejected out-of-range values for bitSize, so the narrowing below is lossless
	switch ptr := dest.(type) {
	case *int:
		*ptr = int(parsed)
	case *int8:
		*ptr = int8(parsed) // #nosec G115
	case *int16:
		*ptr = int16(parsed) // #nosec G115
	case *int32:
		*ptr = int32(parsed) // #nosec G115
	case *int64:
		*ptr = parsed
	}
	return b
}
// intsValue binds all values of sourceParam into an integer slice destination (*[]int64,
// *[]int32, *[]int16, *[]int8 or *[]int). When the parameter has no values nothing is bound;
// if valueMustExist is set, a "required field value is empty" error is recorded instead.
func (b *ValueBinder) intsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	values := b.ValuesFunc(sourceParam)
	if len(values) == 0 {
		if valueMustExist {
			// Report an explicit empty (non-nil) value list, matching every other
			// "required field value is empty" error produced by this binder. Passing `values`
			// here would leak a nil slice whenever ValuesFunc returns nil.
			b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
		}
		return b
	}
	return b.ints(sourceParam, values, dest)
}
// ints converts every entry of values into the slice that dest points to (*[]int64, *[]int32,
// *[]int16, *[]int8 or *[]int). Values are parsed into a scratch slice, which replaces *dest only
// when the binder holds no errors afterwards. In fail-fast mode the first failure stops parsing.
func (b *ValueBinder) ints(sourceParam string, values []string, dest any) *ValueBinder {
	// parseInto runs each value through b.int, writing into the element returned by slot, and
	// reports whether the scratch slice may be assigned to the destination.
	parseInto := func(bitSize int, slot func(i int) any) bool {
		for i, v := range values {
			b.int(sourceParam, v, slot(i), bitSize)
			if b.failFast && b.errors != nil {
				return false
			}
		}
		return b.errors == nil
	}

	switch target := dest.(type) {
	case *[]int64:
		parsed := make([]int64, len(values))
		if parseInto(64, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	case *[]int32:
		parsed := make([]int32, len(values))
		if parseInto(32, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	case *[]int16:
		parsed := make([]int16, len(values))
		if parseInto(16, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	case *[]int8:
		parsed := make([]int8, len(values))
		if parseInto(8, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	case *[]int:
		parsed := make([]int, len(values))
		if parseInto(0, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	}
	return b
}
// Int64s binds every value of sourceParam to a slice of int64. Missing values are not an error.
func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder {
	return b.intsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustInt64s binds every value of sourceParam to a slice of int64 and records a binding error when
// the parameter has no values.
func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder {
	return b.intsValue(sourceParam, dest, true /* valueMustExist */)
}

// Int32s binds every value of sourceParam to a slice of int32. Missing values are not an error.
func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder {
	return b.intsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustInt32s binds every value of sourceParam to a slice of int32 and records a binding error when
// the parameter has no values.
func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder {
	return b.intsValue(sourceParam, dest, true /* valueMustExist */)
}

// Int16s binds every value of sourceParam to a slice of int16. Missing values are not an error.
func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder {
	return b.intsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustInt16s binds every value of sourceParam to a slice of int16 and records a binding error when
// the parameter has no values.
func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder {
	return b.intsValue(sourceParam, dest, true /* valueMustExist */)
}

// Int8s binds every value of sourceParam to a slice of int8. Missing values are not an error.
func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder {
	return b.intsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustInt8s binds every value of sourceParam to a slice of int8 and records a binding error when
// the parameter has no values.
func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder {
	return b.intsValue(sourceParam, dest, true /* valueMustExist */)
}

// Ints binds every value of sourceParam to a slice of int. Missing values are not an error.
func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder {
	return b.intsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustInts binds every value of sourceParam to a slice of int and records a binding error when the
// parameter has no values.
func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder {
	return b.intsValue(sourceParam, dest, true /* valueMustExist */)
}
// Uint64 binds the value of sourceParam to a uint64 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder {
	return b.uintValue(sourceParam, dest, 64 /* bitSize */, false /* valueMustExist */)
}

// MustUint64 binds the value of sourceParam to a uint64 variable and records a binding error when
// the value is missing or empty.
func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder {
	return b.uintValue(sourceParam, dest, 64 /* bitSize */, true /* valueMustExist */)
}

// Uint32 binds the value of sourceParam to a uint32 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder {
	return b.uintValue(sourceParam, dest, 32 /* bitSize */, false /* valueMustExist */)
}

// MustUint32 binds the value of sourceParam to a uint32 variable and records a binding error when
// the value is missing or empty.
func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder {
	return b.uintValue(sourceParam, dest, 32 /* bitSize */, true /* valueMustExist */)
}

// Uint16 binds the value of sourceParam to a uint16 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder {
	return b.uintValue(sourceParam, dest, 16 /* bitSize */, false /* valueMustExist */)
}

// MustUint16 binds the value of sourceParam to a uint16 variable and records a binding error when
// the value is missing or empty.
func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder {
	return b.uintValue(sourceParam, dest, 16 /* bitSize */, true /* valueMustExist */)
}

// Uint8 binds the value of sourceParam to a uint8 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder {
	return b.uintValue(sourceParam, dest, 8 /* bitSize */, false /* valueMustExist */)
}

// MustUint8 binds the value of sourceParam to a uint8 variable and records a binding error when
// the value is missing or empty.
func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder {
	return b.uintValue(sourceParam, dest, 8 /* bitSize */, true /* valueMustExist */)
}

// Byte binds the value of sourceParam, written as a decimal number, to a byte variable. A missing
// or empty value is not an error and leaves dest unchanged.
func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder {
	return b.uintValue(sourceParam, dest, 8 /* bitSize */, false /* valueMustExist */)
}

// MustByte binds the value of sourceParam, written as a decimal number, to a byte variable and
// records a binding error when the value is missing or empty.
func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder {
	return b.uintValue(sourceParam, dest, 8 /* bitSize */, true /* valueMustExist */)
}

// Uint binds the value of sourceParam to a uint variable. A missing or empty value is not an error
// and leaves dest unchanged.
func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder {
	return b.uintValue(sourceParam, dest, 0 /* bitSize: platform uint */, false /* valueMustExist */)
}

// MustUint binds the value of sourceParam to a uint variable and records a binding error when the
// value is missing or empty.
func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder {
	return b.uintValue(sourceParam, dest, 0 /* bitSize: platform uint */, true /* valueMustExist */)
}
// uintValue reads the single value of sourceParam and hands it to uint for conversion into dest.
// An empty value is reported only when valueMustExist is set; otherwise it is silently skipped.
func (b *ValueBinder) uintValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	if value := b.ValueFunc(sourceParam); value != "" {
		return b.uint(sourceParam, value, dest, bitSize)
	}
	if valueMustExist {
		b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
	}
	return b
}
// uint parses value as a base-10 unsigned integer that must fit into bitSize bits (0 selects the
// platform uint size). On success it stores the result into dest, which is expected to be one of
// *uint64, *uint32, *uint16, *uint8 (a.k.a. *byte) or *uint; other destination types are left
// untouched.
func (b *ValueBinder) uint(sourceParam string, value string, dest any, bitSize int) *ValueBinder {
	parsed, err := strconv.ParseUint(value, 10, bitSize)
	if err != nil {
		msg := "failed to bind field value to uint"
		if bitSize != 0 {
			msg = fmt.Sprintf("failed to bind field value to uint%v", bitSize)
		}
		b.setError(b.ErrorFunc(sourceParam, []string{value}, msg, err))
		return b
	}

	// ParseUint already rejected out-of-range values for bitSize, so the narrowing below is lossless
	switch ptr := dest.(type) {
	case *uint:
		*ptr = uint(parsed) // #nosec G115
	case *uint8: // byte is an alias of uint8
		*ptr = uint8(parsed) // #nosec G115
	case *uint16:
		*ptr = uint16(parsed) // #nosec G115
	case *uint32:
		*ptr = uint32(parsed) // #nosec G115
	case *uint64:
		*ptr = parsed
	}
	return b
}
// uintsValue binds all values of sourceParam into an unsigned integer slice destination
// (*[]uint64, *[]uint32, *[]uint16, *[]uint8 or *[]uint). When the parameter has no values nothing
// is bound; if valueMustExist is set, a "required field value is empty" error is recorded instead.
func (b *ValueBinder) uintsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	values := b.ValuesFunc(sourceParam)
	if len(values) == 0 {
		if valueMustExist {
			// Report an explicit empty (non-nil) value list, matching every other
			// "required field value is empty" error produced by this binder. Passing `values`
			// here would leak a nil slice whenever ValuesFunc returns nil.
			b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
		}
		return b
	}
	return b.uints(sourceParam, values, dest)
}
// uints converts every entry of values into the slice that dest points to (*[]uint64, *[]uint32,
// *[]uint16, *[]uint8 or *[]uint). Values are parsed into a scratch slice, which replaces *dest
// only when the binder holds no errors afterwards. In fail-fast mode the first failure stops
// parsing.
func (b *ValueBinder) uints(sourceParam string, values []string, dest any) *ValueBinder {
	// parseInto runs each value through b.uint, writing into the element returned by slot, and
	// reports whether the scratch slice may be assigned to the destination.
	parseInto := func(bitSize int, slot func(i int) any) bool {
		for i, v := range values {
			b.uint(sourceParam, v, slot(i), bitSize)
			if b.failFast && b.errors != nil {
				return false
			}
		}
		return b.errors == nil
	}

	switch target := dest.(type) {
	case *[]uint64:
		parsed := make([]uint64, len(values))
		if parseInto(64, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	case *[]uint32:
		parsed := make([]uint32, len(values))
		if parseInto(32, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	case *[]uint16:
		parsed := make([]uint16, len(values))
		if parseInto(16, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	case *[]uint8: // byte is an alias of uint8
		parsed := make([]uint8, len(values))
		if parseInto(8, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	case *[]uint:
		parsed := make([]uint, len(values))
		if parseInto(0, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	}
	return b
}
// Uint64s binds every value of sourceParam to a slice of uint64. Missing values are not an error.
func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder {
	return b.uintsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustUint64s binds every value of sourceParam to a slice of uint64 and records a binding error
// when the parameter has no values.
func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder {
	return b.uintsValue(sourceParam, dest, true /* valueMustExist */)
}

// Uint32s binds every value of sourceParam to a slice of uint32. Missing values are not an error.
func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder {
	return b.uintsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustUint32s binds every value of sourceParam to a slice of uint32 and records a binding error
// when the parameter has no values.
func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder {
	return b.uintsValue(sourceParam, dest, true /* valueMustExist */)
}

// Uint16s binds every value of sourceParam to a slice of uint16. Missing values are not an error.
func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder {
	return b.uintsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustUint16s binds every value of sourceParam to a slice of uint16 and records a binding error
// when the parameter has no values.
func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder {
	return b.uintsValue(sourceParam, dest, true /* valueMustExist */)
}

// Uint8s binds every value of sourceParam to a slice of uint8. Missing values are not an error.
func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder {
	return b.uintsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustUint8s binds every value of sourceParam to a slice of uint8 and records a binding error when
// the parameter has no values.
func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder {
	return b.uintsValue(sourceParam, dest, true /* valueMustExist */)
}

// Uints binds every value of sourceParam to a slice of uint. Missing values are not an error.
func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder {
	return b.uintsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustUints binds every value of sourceParam to a slice of uint and records a binding error when
// the parameter has no values.
func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder {
	return b.uintsValue(sourceParam, dest, true /* valueMustExist */)
}
// Bool binds the value of sourceParam to a bool variable using strconv.ParseBool rules. A missing
// or empty value is not an error and leaves dest unchanged.
func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder {
	const valueMustExist = false
	return b.boolValue(sourceParam, dest, valueMustExist)
}

// MustBool binds the value of sourceParam to a bool variable and records a binding error when the
// value is missing or empty.
func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder {
	const valueMustExist = true
	return b.boolValue(sourceParam, dest, valueMustExist)
}
// boolValue reads the single value of sourceParam and hands it to bool for conversion into dest.
// An empty value is reported only when valueMustExist is set; otherwise it is silently skipped.
func (b *ValueBinder) boolValue(sourceParam string, dest *bool, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	if value := b.ValueFunc(sourceParam); value != "" {
		return b.bool(sourceParam, value, dest)
	}
	if valueMustExist {
		b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
	}
	return b
}
// bool parses value with strconv.ParseBool and stores the result in dest. A value that cannot be
// parsed is recorded as a binding error and dest is left unchanged.
func (b *ValueBinder) bool(sourceParam string, value string, dest *bool) *ValueBinder {
	if parsed, err := strconv.ParseBool(value); err == nil {
		*dest = parsed
	} else {
		b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to bool", err))
	}
	return b
}
// boolsValue binds all values of sourceParam into a bool slice. When the parameter has no values
// nothing is bound; with valueMustExist set, a "required field value is empty" error is recorded.
func (b *ValueBinder) boolsValue(sourceParam string, dest *[]bool, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	if values := b.ValuesFunc(sourceParam); len(values) > 0 {
		return b.bools(sourceParam, values, dest)
	}
	if valueMustExist {
		b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
	}
	return b
}
// bools parses every entry of values as a bool into a scratch slice and assigns it to *dest only
// when the binder holds no errors afterwards. In fail-fast mode the first failure stops parsing.
func (b *ValueBinder) bools(sourceParam string, values []string, dest *[]bool) *ValueBinder {
	parsed := make([]bool, len(values))
	for i := range values {
		b.bool(sourceParam, values[i], &parsed[i])
		if b.failFast && b.errors != nil {
			return b
		}
	}
	if b.errors != nil {
		return b
	}
	*dest = parsed
	return b
}
// Bools binds every value of sourceParam to a slice of bool. Missing values are not an error.
func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder {
	const valueMustExist = false
	return b.boolsValue(sourceParam, dest, valueMustExist)
}

// MustBools binds every value of sourceParam to a slice of bool and records a binding error when
// the parameter has no values.
func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder {
	const valueMustExist = true
	return b.boolsValue(sourceParam, dest, valueMustExist)
}
// Float64 binds the value of sourceParam to a float64 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder {
	return b.floatValue(sourceParam, dest, 64 /* bitSize */, false /* valueMustExist */)
}

// MustFloat64 binds the value of sourceParam to a float64 variable and records a binding error
// when the value is missing or empty.
func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder {
	return b.floatValue(sourceParam, dest, 64 /* bitSize */, true /* valueMustExist */)
}

// Float32 binds the value of sourceParam to a float32 variable. A missing or empty value is not an
// error and leaves dest unchanged.
func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder {
	return b.floatValue(sourceParam, dest, 32 /* bitSize */, false /* valueMustExist */)
}

// MustFloat32 binds the value of sourceParam to a float32 variable and records a binding error
// when the value is missing or empty.
func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder {
	return b.floatValue(sourceParam, dest, 32 /* bitSize */, true /* valueMustExist */)
}
// floatValue reads the single value of sourceParam and hands it to float for conversion into dest.
// An empty value is reported only when valueMustExist is set; otherwise it is silently skipped.
func (b *ValueBinder) floatValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	if value := b.ValueFunc(sourceParam); value != "" {
		return b.float(sourceParam, value, dest, bitSize)
	}
	if valueMustExist {
		b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
	}
	return b
}
// float parses value with strconv.ParseFloat at the given bitSize and stores the result into dest,
// which is expected to be *float64 or *float32; other destination types are left untouched.
func (b *ValueBinder) float(sourceParam string, value string, dest any, bitSize int) *ValueBinder {
	parsed, err := strconv.ParseFloat(value, bitSize)
	if err != nil {
		b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to float%v", bitSize), err))
		return b
	}
	if p, ok := dest.(*float32); ok {
		*p = float32(parsed)
	} else if p, ok := dest.(*float64); ok {
		*p = parsed
	}
	return b
}
// floatsValue binds all values of sourceParam into a *[]float64 or *[]float32 destination. When
// the parameter has no values nothing is bound; with valueMustExist set, a "required field value
// is empty" error is recorded.
func (b *ValueBinder) floatsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	if values := b.ValuesFunc(sourceParam); len(values) > 0 {
		return b.floats(sourceParam, values, dest)
	}
	if valueMustExist {
		b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
	}
	return b
}
// floats converts every entry of values into the slice that dest points to (*[]float64 or
// *[]float32). Values are parsed into a scratch slice, which replaces *dest only when the binder
// holds no errors afterwards. In fail-fast mode the first failure stops parsing.
func (b *ValueBinder) floats(sourceParam string, values []string, dest any) *ValueBinder {
	// parseInto runs each value through b.float, writing into the element returned by slot, and
	// reports whether the scratch slice may be assigned to the destination.
	parseInto := func(bitSize int, slot func(i int) any) bool {
		for i, v := range values {
			b.float(sourceParam, v, slot(i), bitSize)
			if b.failFast && b.errors != nil {
				return false
			}
		}
		return b.errors == nil
	}

	switch target := dest.(type) {
	case *[]float64:
		parsed := make([]float64, len(values))
		if parseInto(64, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	case *[]float32:
		parsed := make([]float32, len(values))
		if parseInto(32, func(i int) any { return &parsed[i] }) {
			*target = parsed
		}
	}
	return b
}
// Float64s binds every value of sourceParam to a slice of float64. Missing values are not an error.
func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder {
	return b.floatsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustFloat64s binds every value of sourceParam to a slice of float64 and records a binding error
// when the parameter has no values.
func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder {
	return b.floatsValue(sourceParam, dest, true /* valueMustExist */)
}

// Float32s binds every value of sourceParam to a slice of float32. Missing values are not an error.
func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder {
	return b.floatsValue(sourceParam, dest, false /* valueMustExist */)
}

// MustFloat32s binds every value of sourceParam to a slice of float32 and records a binding error
// when the parameter has no values.
func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder {
	return b.floatsValue(sourceParam, dest, true /* valueMustExist */)
}
// Time binds the value of sourceParam to a time.Time variable, parsing it with time.Parse and the
// given layout. A missing or empty value is not an error and leaves dest unchanged.
func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) *ValueBinder {
	const valueMustExist = false
	return b.time(sourceParam, dest, layout, valueMustExist)
}

// MustTime binds the value of sourceParam to a time.Time variable, parsing it with time.Parse and
// the given layout, and records a binding error when the value is missing or empty.
func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder {
	const valueMustExist = true
	return b.time(sourceParam, dest, layout, valueMustExist)
}
// time parses the single value of sourceParam with time.Parse(layout, value) and stores the result
// in dest. An empty value is skipped, or reported when valueMustExist is set.
func (b *ValueBinder) time(sourceParam string, dest *time.Time, layout string, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	value := b.ValueFunc(sourceParam)
	switch {
	case value == "" && valueMustExist:
		b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
	case value == "":
		// optional and absent: nothing to bind
	default:
		if parsed, err := time.Parse(layout, value); err != nil {
			b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err))
		} else {
			*dest = parsed
		}
	}
	return b
}
// Times binds every value of sourceParam to a slice of time.Time, parsing each with time.Parse and
// the given layout. Missing values are not an error.
func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string) *ValueBinder {
	const valueMustExist = false
	return b.times(sourceParam, dest, layout, valueMustExist)
}

// MustTimes binds every value of sourceParam to a slice of time.Time and records a binding error
// when the parameter has no values.
func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder {
	const valueMustExist = true
	return b.times(sourceParam, dest, layout, valueMustExist)
}
// times parses every value of sourceParam with time.Parse(layout, value) into a scratch slice and
// assigns it to *dest only when the binder holds no errors afterwards. Each unparsable value is
// recorded as its own binding error; in fail-fast mode the first one stops parsing.
func (b *ValueBinder) times(sourceParam string, dest *[]time.Time, layout string, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	values := b.ValuesFunc(sourceParam)
	if len(values) == 0 {
		if valueMustExist {
			b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
		}
		return b
	}

	parsed := make([]time.Time, len(values))
	for i, raw := range values {
		t, err := time.Parse(layout, raw)
		if err == nil {
			parsed[i] = t
			continue
		}
		b.setError(b.ErrorFunc(sourceParam, []string{raw}, "failed to bind field value to Time", err))
		if b.failFast {
			return b
		}
	}
	if b.errors == nil {
		*dest = parsed
	}
	return b
}
// Duration binds the value of sourceParam to a time.Duration variable using time.ParseDuration
// syntax (e.g. "1h30m"). A missing or empty value is not an error and leaves dest unchanged.
func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBinder {
	const valueMustExist = false
	return b.duration(sourceParam, dest, valueMustExist)
}

// MustDuration binds the value of sourceParam to a time.Duration variable and records a binding
// error when the value is missing or empty.
func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder {
	const valueMustExist = true
	return b.duration(sourceParam, dest, valueMustExist)
}
// duration parses the single value of sourceParam with time.ParseDuration and stores the result in
// dest. An empty value is skipped, or reported when valueMustExist is set.
func (b *ValueBinder) duration(sourceParam string, dest *time.Duration, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	value := b.ValueFunc(sourceParam)
	switch {
	case value == "" && valueMustExist:
		b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
	case value == "":
		// optional and absent: nothing to bind
	default:
		if parsed, err := time.ParseDuration(value); err != nil {
			b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Duration", err))
		} else {
			*dest = parsed
		}
	}
	return b
}
// Durations binds every value of sourceParam to a slice of time.Duration. Missing values are not
// an error.
func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *ValueBinder {
	const valueMustExist = false
	return b.durationsValue(sourceParam, dest, valueMustExist)
}

// MustDurations binds every value of sourceParam to a slice of time.Duration and records a binding
// error when the parameter has no values.
func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder {
	const valueMustExist = true
	return b.durationsValue(sourceParam, dest, valueMustExist)
}
// durationsValue binds all values of sourceParam into a time.Duration slice. When the parameter
// has no values nothing is bound; with valueMustExist set, a "required field value is empty" error
// is recorded.
func (b *ValueBinder) durationsValue(sourceParam string, dest *[]time.Duration, valueMustExist bool) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	if values := b.ValuesFunc(sourceParam); len(values) > 0 {
		return b.durations(sourceParam, values, dest)
	}
	if valueMustExist {
		b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
	}
	return b
}
// durations parses every entry of values with time.ParseDuration into a scratch slice and assigns
// it to *dest only when the binder holds no errors afterwards. Each unparsable value is recorded as
// its own binding error; in fail-fast mode the first one stops parsing.
func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]time.Duration) *ValueBinder {
	parsed := make([]time.Duration, len(values))
	for i := range values {
		d, err := time.ParseDuration(values[i])
		if err == nil {
			parsed[i] = d
			continue
		}
		b.setError(b.ErrorFunc(sourceParam, []string{values[i]}, "failed to bind field value to Duration", err))
		if b.failFast {
			return b
		}
	}
	if b.errors == nil {
		*dest = parsed
	}
	return b
}
// UnixTime binds the parameter, a base-10 Unix timestamp in seconds, to a time.Time variable (in
// local time corresponding to the given Unix time).
//
// Example: 1609180603 binds to 2020-12-28T18:36:43.000000000+00:00
//
// Note:
//   - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder {
	return b.unixTime(sourceParam, dest, false, time.Second)
}

// MustUnixTime requires the parameter value to exist and binds it, as a base-10 Unix timestamp in
// seconds, to a time.Time variable (in local time corresponding to the given Unix time). Returns
// an error when the value does not exist.
//
// Example: 1609180603 binds to 2020-12-28T18:36:43.000000000+00:00
//
// Note:
//   - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder {
	return b.unixTime(sourceParam, dest, true, time.Second)
}

// UnixTimeMilli binds the parameter, a base-10 Unix timestamp in milliseconds, to a time.Time
// variable (in local time corresponding to the given Unix time in millisecond precision).
//
// Example: 1647184410140 binds to 2022-03-13T15:13:30.140000000+00:00
//
// Note:
//   - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
	return b.unixTime(sourceParam, dest, false, time.Millisecond)
}

// MustUnixTimeMilli requires the parameter value to exist and binds it, as a base-10 Unix timestamp
// in milliseconds, to a time.Time variable (in local time corresponding to the given Unix time in
// millisecond precision). Returns an error when the value does not exist.
//
// Example: 1647184410140 binds to 2022-03-13T15:13:30.140000000+00:00
//
// Note:
//   - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
	return b.unixTime(sourceParam, dest, true, time.Millisecond)
}

// UnixTimeNano binds the parameter, a base-10 Unix timestamp in nanoseconds, to a time.Time
// variable (in local time corresponding to the given Unix time in nanosecond precision).
//
// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
//
// Note:
//   - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
//   - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
	return b.unixTime(sourceParam, dest, false, time.Nanosecond)
}

// MustUnixTimeNano requires the parameter value to exist and binds it, as a base-10 Unix timestamp
// in nanoseconds, to a time.Time variable (in local time corresponding to the given Unix time in
// nanosecond precision). Returns an error when the value does not exist.
//
// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00
// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
//
// Note:
//   - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
//   - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
	return b.unixTime(sourceParam, dest, true, time.Nanosecond)
}
// unixTime parses the single value of sourceParam as a base-10 int64 Unix timestamp expressed in
// units of precision (time.Second, time.Millisecond or time.Nanosecond) and stores the matching
// time.Time in dest. An empty value is skipped, or reported when valueMustExist is set. Any other
// precision leaves dest untouched.
func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder {
	if b.failFast && b.errors != nil {
		return b
	}
	value := b.ValueFunc(sourceParam)
	if value == "" && valueMustExist {
		b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
	}
	if value == "" {
		return b
	}

	ts, err := strconv.ParseInt(value, 10, 64)
	if err != nil {
		b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err))
		return b
	}

	switch precision {
	case time.Nanosecond:
		*dest = time.Unix(0, ts)
	case time.Millisecond:
		*dest = time.UnixMilli(ts)
	case time.Second:
		*dest = time.Unix(ts, 0)
	}
	return b
}
================================================
FILE: binder_external_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
// run tests as external package to get real feel for API
package echo_test
import (
"encoding/base64"
"fmt"
"log"
"net/http"
"net/http/httptest"
"github.com/labstack/echo/v5"
)
func ExampleValueBinder_BindErrors() {
	type options struct {
		IDs    []int64
		Active bool
	}

	// handler binds query params into several destinations and reports every bind error at once
	handler := func(c *echo.Context) error {
		var opts options
		length := int64(50) // fallback used when "length" is not in the query

		errs := echo.QueryParamsBinder(c).
			Int64("length", &length).
			Int64s("ids", &opts.IDs).
			Bool("active", &opts.Active).
			BindErrors() // all collected errors, not only the first one
		if errs != nil {
			for _, err := range errs {
				bindErr := err.(*echo.BindingError)
				log.Printf("in case you want to access what field: %s values: %v failed", bindErr.Field, bindErr.Values)
			}
			return fmt.Errorf("%v fields failed to bind", len(errs))
		}

		fmt.Printf("active = %v, length = %v, ids = %v", opts.Active, length, opts.IDs)
		return c.JSON(http.StatusOK, opts)
	}

	req := httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil)
	rec := httptest.NewRecorder()
	_ = handler(echo.New().NewContext(req, rec))

	// Output: active = true, length = 25, ids = [1 2 3]
}
func ExampleValueBinder_BindError() {
	// failFastHandler binds query params into several destinations and reacts only to the first
	// binding error
	failFastHandler := func(c *echo.Context) error {
		var opts struct {
			IDs    []int64
			Active bool
		}
		length := int64(50) // fallback used when "length" is not in the query

		binder := echo.QueryParamsBinder(c)
		if err := binder.Int64("length", &length).
			Int64s("ids", &opts.IDs).
			Bool("active", &opts.Active).
			BindError(); err != nil { // BindError returns only the first binding error
			bindErr := err.(*echo.BindingError)
			return fmt.Errorf("my own custom error for field: %s values: %v", bindErr.Field, bindErr.Values)
		}

		fmt.Printf("active = %v, length = %v, ids = %v\n", opts.Active, length, opts.IDs)
		return c.JSON(http.StatusOK, opts)
	}

	req := httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil)
	rec := httptest.NewRecorder()
	_ = failFastHandler(echo.New().NewContext(req, rec))

	// Output: active = true, length = 25, ids = [1 2 3]
}
func ExampleValueBinder_CustomFunc() {
	// handler binds query params, decoding one of them with a custom closure
	handler := func(c *echo.Context) error {
		length := int64(50) // fallback used when "length" is not in the query
		var binary []byte

		decodeBase64 := func(values []string) []error {
			if len(values) == 0 {
				return nil
			}
			// only the first value is used here, but a URL may carry several values for the same
			// param and therefore, in theory, produce several binding errors
			decoded, err := base64.URLEncoding.DecodeString(values[0])
			if err != nil {
				return []error{echo.NewBindingError("base64", values[0:1], "failed to decode base64", err)}
			}
			binary = decoded
			return nil
		}

		errs := echo.QueryParamsBinder(c).
			Int64("length", &length).
			CustomFunc("base64", decodeBase64).
			BindErrors() // all collected errors, not only the first one
		if errs != nil {
			for _, err := range errs {
				bindErr := err.(*echo.BindingError)
				log.Printf("in case you want to access what field: %s values: %v failed", bindErr.Field, bindErr.Values)
			}
			return fmt.Errorf("%v fields failed to bind", len(errs))
		}

		fmt.Printf("length = %v, base64 = %s", length, binary)
		return c.JSON(http.StatusOK, "ok")
	}

	req := httptest.NewRequest(http.MethodGet, "/api/endpoint?length=25&base64=SGVsbG8gV29ybGQ%3D", nil)
	rec := httptest.NewRecorder()
	_ = handler(echo.New().NewContext(req, rec))

	// Output: length = 25, base64 = Hello World
}
================================================
FILE: binder_generic.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"encoding"
"encoding/json"
"fmt"
"strconv"
"time"
)
// TimeLayout names the format used to parse time values taken from request parameters.
// Besides any standard Go reference-time layout (for example "2006-01-02"), it may be one
// of the TimeLayoutUnixTime* constants, which interpret the value as an integer timestamp.
type TimeLayout string

// TimeOpts groups the options that control how time.Time values are parsed.
type TimeOpts struct {
	// Layout is the parsing format. Either a standard Go time layout string or one of the
	// special Unix timestamp layouts.
	//
	// When empty, echo.TimeLayout(time.RFC3339Nano) is used.
	//   - Custom layout: `echo.TimeLayout("2006-01-02")`
	//   - Unix timestamp in seconds: `echo.TimeLayoutUnixTime`
	//   - Unix timestamp in milliseconds: `echo.TimeLayoutUnixTimeMilli`
	//   - Unix timestamp in nanoseconds: `echo.TimeLayoutUnixTimeNano`
	Layout TimeLayout

	// ParseInLocation is passed to time.ParseInLocation and is applied to layouts that
	// carry no timezone information. When nil, time.UTC is used.
	ParseInLocation *time.Location

	// ToInLocation is the location the parsed time is converted into via
	// time.Time.In(ToInLocation). When nil, time.UTC is used.
	ToInLocation *time.Location
}

// Special TimeLayout values that parse integer Unix timestamps of different precision.
const (
	// TimeLayoutUnixTime parses a Unix timestamp expressed in seconds.
	TimeLayoutUnixTime TimeLayout = "UnixTime"
	// TimeLayoutUnixTimeMilli parses a Unix timestamp expressed in milliseconds.
	TimeLayoutUnixTimeMilli TimeLayout = "UnixTimeMilli"
	// TimeLayoutUnixTimeNano parses a Unix timestamp expressed in nanoseconds.
	TimeLayoutUnixTimeNano TimeLayout = "UnixTimeNano"
)
// PathParam looks up the path parameter called paramName on c and parses its value into T.
// A parse failure is reported as a BindingError. When no path parameter with that name
// exists, ErrNonExistentKey is returned together with the zero value of T.
//
// Empty String Handling:
// A parameter that exists but whose value is the empty string yields the zero value of T
// and a nil error (e.g. (0, nil) for int types), unlike the strconv parsers which reject
// empty input. Validate the result or inspect the raw value if empty input must fail.
//
// See ParseValue for supported types and options
func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) {
	var zero T
	for _, pv := range c.PathValues() {
		if pv.Name != paramName {
			continue
		}
		parsed, err := ParseValue[T](pv.Value, opts...)
		if err != nil {
			return parsed, NewBindingError(paramName, []string{pv.Value}, "path value", err)
		}
		return parsed, nil
	}
	return zero, ErrNonExistentKey
}
// PathParamOr looks up the path parameter called paramName on c and parses its value into T.
// defaultValue is returned (with a nil error) when the parameter is absent or its value is
// empty. An error is returned only when parsing fails (e.g. "abc" for int); in that case
// the zero value of T is returned together with a BindingError.
//
// Example:
//
//	id, err := echo.PathParamOr[int](c, "id", 0)
//	// If "id" is missing: returns (0, nil)
//	// If "id" is "123": returns (123, nil)
//	// If "id" is "abc": returns (0, BindingError)
//
// See ParseValue for supported types and options
func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) {
	for _, pv := range c.PathValues() {
		if pv.Name != paramName {
			continue
		}
		parsed, err := ParseValueOr[T](pv.Value, defaultValue, opts...)
		if err != nil {
			return parsed, NewBindingError(paramName, []string{pv.Value}, "path value", err)
		}
		return parsed, nil
	}
	return defaultValue, nil
}
// QueryParam reads the first value of the query parameter key and parses it into T.
// A parse failure is reported as a BindingError; a missing key yields ErrNonExistentKey.
//
// Empty String Handling:
// A key that is present with an empty value (?key=) yields the zero value of T and a nil
// error, e.g. "?count=" gives (0, nil) for int types. This is more lenient than the strconv
// parsers, which reject empty input. Validate the result or inspect the raw value if empty
// input must fail.
//
// Behavior Summary:
//   - Missing key (?other=value): returns (zero, ErrNonExistentKey)
//   - Empty value (?key=): returns (zero, nil)
//   - Invalid value (?key=abc for int): returns (zero, BindingError)
//
// See ParseValue for supported types and options
func QueryParam[T any](c *Context, key string, opts ...any) (T, error) {
	var zero T
	values, found := c.QueryParams()[key]
	switch {
	case !found:
		return zero, ErrNonExistentKey
	case len(values) == 0:
		return zero, nil
	}

	raw := values[0]
	parsed, err := ParseValue[T](raw, opts...)
	if err != nil {
		return parsed, NewBindingError(key, []string{raw}, "query param", err)
	}
	return parsed, nil
}
// QueryParamOr extracts and parses a single query parameter from the request by key.
// Returns defaultValue if the parameter is not found or has an empty value.
// Returns an error only if parsing fails (e.g., "abc" for int type). In that case the
// zero value of T, not defaultValue, is returned together with a BindingError.
//
// Example:
//
//	page, err := echo.QueryParamOr[int](c, "page", 1)
//	// If "page" is missing: returns (1, nil)
//	// If "page" is "5": returns (5, nil)
//	// If "page" is "abc": returns (0, BindingError)
//
// See ParseValue for supported types and options
func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) {
	// A missing key yields a nil slice, so a single length check covers both
	// "not present" and "present without any values".
	values := c.QueryParams()[key]
	if len(values) == 0 {
		return defaultValue, nil
	}
	value := values[0]
	v, err := ParseValueOr[T](value, defaultValue, opts...)
	if err != nil {
		return v, NewBindingError(key, []string{value}, "query param", err)
	}
	return v, nil
}
// QueryParams parses every value supplied for the query parameter key into a []T.
// A failure to parse any single value is reported as a BindingError covering all values;
// a missing key yields ErrNonExistentKey. In both error cases the returned slice is nil.
//
// See ParseValues for supported types and options
func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) {
	raw, found := c.QueryParams()[key]
	if !found {
		return nil, ErrNonExistentKey
	}
	parsed, err := ParseValues[T](raw, opts...)
	if err == nil {
		return parsed, nil
	}
	return nil, NewBindingError(key, raw, "query params", err)
}
// QueryParamsOr extracts and parses all values for a query parameter key as a slice.
// Returns defaultValue if the parameter is not found or is present without any values.
// Returns an error only if parsing any value fails; in that case the returned slice is nil.
//
// Example:
//
//	ids, err := echo.QueryParamsOr[int](c, "ids", []int{})
//	// If "ids" is missing: returns ([], nil)
//	// If query is "?ids=1&ids=2": returns ([1, 2], nil)
//	// If any "ids" value is "abc": returns (nil, BindingError)
//
// See ParseValues for supported types and options
func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) {
	values, ok := c.QueryParams()[key]
	if !ok {
		return defaultValue, nil
	}
	// ParseValuesOr itself falls back to defaultValue when values is empty.
	result, err := ParseValuesOr[T](values, defaultValue, opts...)
	if err != nil {
		return nil, NewBindingError(key, values, "query params", err)
	}
	return result, nil
}
// FormValue reads the first value of the form field key (as returned by c.FormValues)
// and parses it into T. A failure to parse the form itself is returned as a wrapped
// error; a parse failure of the value is a BindingError; a missing key yields
// ErrNonExistentKey.
//
// Empty String Handling:
// A field that exists but holds an empty value yields the zero value of T and a nil error,
// e.g. (0, nil) for int types. This is more lenient than the strconv parsers, which reject
// empty input. Validate the result or inspect the raw value if empty input must fail.
//
// See ParseValue for supported types and options
func FormValue[T any](c *Context, key string, opts ...any) (T, error) {
	var zero T
	formValues, err := c.FormValues()
	if err != nil {
		return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err)
	}

	values, found := formValues[key]
	switch {
	case !found:
		return zero, ErrNonExistentKey
	case len(values) == 0:
		return zero, nil
	}

	first := values[0]
	parsed, parseErr := ParseValue[T](first, opts...)
	if parseErr != nil {
		return parsed, NewBindingError(key, []string{first}, "form value", parseErr)
	}
	return parsed, nil
}
// FormValueOr extracts and parses a single form value from the request by key.
// Returns defaultValue if the parameter is not found or has an empty value.
// Returns an error only if parsing fails or form parsing errors occur. In both error cases
// the zero value of T, not defaultValue, is returned.
//
// Example:
//
//	limit, err := echo.FormValueOr[int](c, "limit", 100)
//	// If "limit" is missing: returns (100, nil)
//	// If "limit" is "50": returns (50, nil)
//	// If "limit" is "abc": returns (0, BindingError)
//
// See ParseValue for supported types and options
func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) {
	formValues, err := c.FormValues()
	if err != nil {
		var zero T
		return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err)
	}
	// A missing key yields a nil slice, so one length check covers both
	// "not present" and "present without any values".
	values := formValues[key]
	if len(values) == 0 {
		return defaultValue, nil
	}
	value := values[0]
	v, err := ParseValueOr[T](value, defaultValue, opts...)
	if err != nil {
		return v, NewBindingError(key, []string{value}, "form value", err)
	}
	return v, nil
}
// FormValues parses every value of the form field key into a []T.
// A failure to parse the form is returned as a wrapped error, a failure to parse any value
// is a BindingError covering all values, and a missing key yields ErrNonExistentKey. The
// returned slice is nil in every error case.
//
// See ParseValues for supported types and options
func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) {
	formValues, err := c.FormValues()
	if err != nil {
		return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err)
	}

	raw, found := formValues[key]
	if !found {
		return nil, ErrNonExistentKey
	}

	parsed, parseErr := ParseValues[T](raw, opts...)
	if parseErr == nil {
		return parsed, nil
	}
	return nil, NewBindingError(key, raw, "form values", parseErr)
}
// FormValuesOr parses every value of the form field key into a []T, falling back to
// defaultValue when the key is absent or has no values.
// An error is returned only if the form cannot be parsed or any value fails to parse;
// the returned slice is nil in both cases.
//
// Example:
//
//	tags, err := echo.FormValuesOr[string](c, "tags", []string{})
//	// If "tags" is missing: returns ([], nil)
//	// If form parsing fails: returns (nil, error)
//
// See ParseValues for supported types and options
func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) {
	formValues, err := c.FormValues()
	if err != nil {
		return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err)
	}

	raw, found := formValues[key]
	if !found {
		return defaultValue, nil
	}

	parsed, parseErr := ParseValuesOr[T](raw, defaultValue, opts...)
	if parseErr == nil {
		return parsed, nil
	}
	return nil, NewBindingError(key, raw, "form values", parseErr)
}
// ParseValues converts each string in values into T and returns the results as a slice.
// It accepts exactly the types ParseValue does. An empty input yields a nil slice; the
// first value that fails to parse aborts the conversion with that error.
//
// See ParseValue for supported types and options
func ParseValues[T any](values []string, opts ...any) ([]T, error) {
	return ParseValuesOr[T](values, nil, opts...)
}
// ParseValuesOr converts each string in values into T. When values is empty,
// defaultValue is returned unchanged. It accepts exactly the types ParseValue does;
// the first value that fails to parse aborts the conversion and nil is returned with
// that error.
//
// See ParseValue for supported types and options
func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error) {
	count := len(values)
	if count == 0 {
		return defaultValue, nil
	}

	out := make([]T, count)
	for i := range values {
		item, err := ParseValue[T](values[i], opts...)
		if err != nil {
			return nil, err
		}
		out[i] = item
	}
	return out, nil
}
// ParseValue converts value into T. An empty value yields the zero value of T.
//
// Supported target types:
//   - bool
//   - float32, float64
//   - int, int8, int16, int32, int64
//   - uint, uint8/byte, uint16, uint32, uint64
//   - string
//   - echo.BindUnmarshaler interface
//   - encoding.TextUnmarshaler interface
//   - json.Unmarshaler interface
//   - time.Duration
//   - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration
func ParseValue[T any](value string, opts ...any) (T, error) {
	return ParseValueOr[T](value, *new(T), opts...)
}
// ParseValueOr converts value into T, returning defaultValue when value is empty.
// On a parse failure the zero value of T is returned with an error that wraps the
// underlying parser error.
//
// Supported target types:
//   - bool
//   - float32, float64
//   - int, int8, int16, int32, int64
//   - uint, uint8/byte, uint16, uint32, uint64
//   - string
//   - echo.BindUnmarshaler interface
//   - encoding.TextUnmarshaler interface
//   - json.Unmarshaler interface
//   - time.Duration
//   - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration
func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error) {
	if value == "" {
		return defaultValue, nil
	}

	var parsed T
	err := bindValue(value, &parsed, opts...)
	if err == nil {
		return parsed, nil
	}
	// parsed may have been partially written by a custom unmarshaler; discard it.
	var zero T
	return zero, fmt.Errorf("failed to parse value, err: %w", err)
}
// bindValue parses the string value into the variable that dest points to.
//
// dest must be a pointer to one of the types listed in ParseValue's documentation.
// The order of the type switch matters:
//   - Concrete pointer types are matched first. *time.Time in particular also satisfies
//     encoding.TextUnmarshaler and json.Unmarshaler, but must use the TimeOpts-aware
//     parsing below.
//   - The interface cases are then tried in priority order: BindUnmarshaler, then
//     encoding.TextUnmarshaler, then json.Unmarshaler.
//
// opts are accepted only when dest is *time.Time. Each option must be a TimeOpts or a
// TimeLayout, and options are applied in order. Any other option type is reported as an
// error rather than silently ignored.
//
// On error the value behind dest may already have been modified (for example by a custom
// unmarshaler), so callers should discard it.
func bindValue(value string, dest any, opts ...any) error {
	// NOTE: if this function is ever made public the dest should be checked for nil
	// values when dealing with interfaces
	if len(opts) > 0 {
		if _, isTime := dest.(*time.Time); !isTime {
			return fmt.Errorf("options are only supported for time.Time, got %T", dest)
		}
	}

	switch d := dest.(type) {
	case *bool:
		n, err := strconv.ParseBool(value)
		if err != nil {
			return err
		}
		*d = n
	case *float32:
		n, err := strconv.ParseFloat(value, 32)
		if err != nil {
			return err
		}
		*d = float32(n)
	case *float64:
		n, err := strconv.ParseFloat(value, 64)
		if err != nil {
			return err
		}
		*d = n
	case *int:
		// bitSize 0 means "size of int" for the current platform
		n, err := strconv.ParseInt(value, 10, 0)
		if err != nil {
			return err
		}
		*d = int(n)
	case *int8:
		n, err := strconv.ParseInt(value, 10, 8)
		if err != nil {
			return err
		}
		*d = int8(n)
	case *int16:
		n, err := strconv.ParseInt(value, 10, 16)
		if err != nil {
			return err
		}
		*d = int16(n)
	case *int32:
		n, err := strconv.ParseInt(value, 10, 32)
		if err != nil {
			return err
		}
		*d = int32(n)
	case *int64:
		n, err := strconv.ParseInt(value, 10, 64)
		if err != nil {
			return err
		}
		*d = n
	case *uint:
		n, err := strconv.ParseUint(value, 10, 0)
		if err != nil {
			return err
		}
		*d = uint(n)
	case *uint8:
		n, err := strconv.ParseUint(value, 10, 8)
		if err != nil {
			return err
		}
		*d = uint8(n)
	case *uint16:
		n, err := strconv.ParseUint(value, 10, 16)
		if err != nil {
			return err
		}
		*d = uint16(n)
	case *uint32:
		n, err := strconv.ParseUint(value, 10, 32)
		if err != nil {
			return err
		}
		*d = uint32(n)
	case *uint64:
		n, err := strconv.ParseUint(value, 10, 64)
		if err != nil {
			return err
		}
		*d = n
	case *string:
		*d = value
	case *time.Duration:
		t, err := time.ParseDuration(value)
		if err != nil {
			return err
		}
		*d = t
	case *time.Time:
		to := TimeOpts{
			Layout:          TimeLayout(time.RFC3339Nano),
			ParseInLocation: time.UTC,
			ToInLocation:    time.UTC,
		}
		for _, o := range opts {
			switch v := o.(type) {
			case TimeOpts:
				if v.Layout != "" {
					to.Layout = v.Layout
				}
				if v.ParseInLocation != nil {
					to.ParseInLocation = v.ParseInLocation
				}
				if v.ToInLocation != nil {
					to.ToInLocation = v.ToInLocation
				}
			case TimeLayout:
				to.Layout = v
			default:
				// Reject unknown options (e.g. a bare *time.Location) instead of silently
				// parsing with defaults, which would produce a wrong time with no error.
				return fmt.Errorf("unsupported option type for time.Time: %T", o)
			}
		}

		var t time.Time
		switch to.Layout {
		case TimeLayoutUnixTime:
			n, err := strconv.ParseInt(value, 10, 64)
			if err != nil {
				return err
			}
			t = time.Unix(n, 0)
		case TimeLayoutUnixTimeMilli:
			n, err := strconv.ParseInt(value, 10, 64)
			if err != nil {
				return err
			}
			t = time.UnixMilli(n)
		case TimeLayoutUnixTimeNano:
			n, err := strconv.ParseInt(value, 10, 64)
			if err != nil {
				return err
			}
			t = time.Unix(0, n)
		default:
			// to.ParseInLocation is never nil here: it defaults to time.UTC and is only
			// overridden by non-nil values above.
			parsed, err := time.ParseInLocation(string(to.Layout), value, to.ParseInLocation)
			if err != nil {
				return err
			}
			t = parsed
		}
		*d = t.In(to.ToInLocation)
	case BindUnmarshaler:
		return d.UnmarshalParam(value)
	case encoding.TextUnmarshaler:
		return d.UnmarshalText([]byte(value))
	case json.Unmarshaler:
		return d.UnmarshalJSON([]byte(value))
	default:
		return fmt.Errorf("unsupported value type: %T", dest)
	}
	return nil
}
================================================
FILE: binder_generic_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"cmp"
"encoding/json"
"fmt"
"math"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TextUnmarshalerType is a test helper that satisfies encoding.TextUnmarshaler only
// (it deliberately does NOT implement BindUnmarshaler).
type TextUnmarshalerType struct {
	Value string
}

// UnmarshalText stores the upper-cased input, rejecting the literal "invalid".
func (t *TextUnmarshalerType) UnmarshalText(data []byte) error {
	s := string(data)
	switch s {
	case "invalid":
		return fmt.Errorf("invalid value: %s", s)
	default:
		t.Value = strings.ToUpper(s)
		return nil
	}
}

// JSONUnmarshalerType is a test helper that satisfies json.Unmarshaler only
// (neither BindUnmarshaler nor encoding.TextUnmarshaler).
type JSONUnmarshalerType struct {
	Value string
}

// UnmarshalJSON decodes a JSON string document into Value.
func (u *JSONUnmarshalerType) UnmarshalJSON(raw []byte) error {
	return json.Unmarshal(raw, &u.Value)
}
// TestPathParam checks typed lookup of a bool path parameter: found, missing and unparsable.
func TestPathParam(t *testing.T) {
	tests := []struct {
		name       string
		givenKey   string
		givenValue string
		expect     bool
		expectErr  string
	}{
		{name: "ok", givenValue: "true", expect: true},
		{name: "nok, non existent key", givenKey: "missing", givenValue: "true", expect: false, expectErr: ErrNonExistentKey.Error()},
		{name: "nok, invalid value", givenValue: "can_parse_me", expect: false, expectErr: `code=400, message=path value, err=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(nil, nil)
			paramName := cmp.Or(tt.givenKey, "key")
			ctx.SetPathValues(PathValues{{Name: paramName, Value: tt.givenValue}})

			got, err := PathParam[bool](ctx, "key")

			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestPathParam_UnsupportedType checks that a slice target type is rejected with a BindingError.
func TestPathParam_UnsupportedType(t *testing.T) {
	ctx := NewContext(nil, nil)
	ctx.SetPathValues(PathValues{{Name: "key", Value: "true"}})

	got, err := PathParam[[]bool](ctx, "key")

	assert.EqualError(t, err, "code=400, message=path value, err=failed to parse value, err: unsupported value type: *[]bool, field=key")
	assert.Equal(t, []bool(nil), got)
}
// TestQueryParam checks typed lookup of a bool query parameter: found, missing and unparsable.
func TestQueryParam(t *testing.T) {
	tests := []struct {
		name      string
		givenURL  string
		expect    bool
		expectErr string
	}{
		{name: "ok", givenURL: "/?key=true", expect: true},
		{name: "nok, non existent key", givenURL: "/?different=true", expect: false, expectErr: ErrNonExistentKey.Error()},
		{name: "nok, invalid value", givenURL: "/?key=invalidbool", expect: false, expectErr: `code=400, message=query param, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(httptest.NewRequest(http.MethodPost, tt.givenURL, nil), nil)

			got, err := QueryParam[bool](ctx, "key")

			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestQueryParam_UnsupportedType checks that a slice target type is rejected with a BindingError.
func TestQueryParam_UnsupportedType(t *testing.T) {
	ctx := NewContext(httptest.NewRequest(http.MethodPost, "/?key=bool", nil), nil)

	got, err := QueryParam[[]bool](ctx, "key")

	assert.EqualError(t, err, "code=400, message=query param, err=failed to parse value, err: unsupported value type: *[]bool, field=key")
	assert.Equal(t, []bool(nil), got)
}
// TestQueryParams checks parsing every value of a repeated query parameter into []bool.
func TestQueryParams(t *testing.T) {
	tests := []struct {
		name      string
		givenURL  string
		expect    []bool
		expectErr string
	}{
		{name: "ok", givenURL: "/?key=true&key=false", expect: []bool{true, false}},
		{name: "nok, non existent key", givenURL: "/?different=true", expect: []bool(nil), expectErr: ErrNonExistentKey.Error()},
		{name: "nok, invalid value", givenURL: "/?key=true&key=invalidbool", expect: []bool(nil), expectErr: `code=400, message=query params, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(httptest.NewRequest(http.MethodPost, tt.givenURL, nil), nil)

			got, err := QueryParams[bool](ctx, "key")

			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestQueryParams_UnsupportedType checks that a slice element type is rejected with a BindingError.
func TestQueryParams_UnsupportedType(t *testing.T) {
	ctx := NewContext(httptest.NewRequest(http.MethodPost, "/?key=bool", nil), nil)

	got, err := QueryParams[[]bool](ctx, "key")

	assert.EqualError(t, err, "code=400, message=query params, err=failed to parse value, err: unsupported value type: *[]bool, field=key")
	assert.Equal(t, [][]bool(nil), got)
}
// TestFormValue checks typed lookup of a bool form value: found, missing and unparsable.
func TestFormValue(t *testing.T) {
	tests := []struct {
		name      string
		givenURL  string
		expect    bool
		expectErr string
	}{
		{name: "ok", givenURL: "/?key=true", expect: true},
		{name: "nok, non existent key", givenURL: "/?different=true", expect: false, expectErr: ErrNonExistentKey.Error()},
		{name: "nok, invalid value", givenURL: "/?key=invalidbool", expect: false, expectErr: `code=400, message=form value, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(httptest.NewRequest(http.MethodPost, tt.givenURL, nil), nil)

			got, err := FormValue[bool](ctx, "key")

			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestFormValue_UnsupportedType checks that a slice target type is rejected with a BindingError.
func TestFormValue_UnsupportedType(t *testing.T) {
	ctx := NewContext(httptest.NewRequest(http.MethodPost, "/?key=bool", nil), nil)

	got, err := FormValue[[]bool](ctx, "key")

	assert.EqualError(t, err, "code=400, message=form value, err=failed to parse value, err: unsupported value type: *[]bool, field=key")
	assert.Equal(t, []bool(nil), got)
}
// TestFormValues checks parsing every value of a repeated form field into []bool.
func TestFormValues(t *testing.T) {
	tests := []struct {
		name      string
		givenURL  string
		expect    []bool
		expectErr string
	}{
		{name: "ok", givenURL: "/?key=true&key=false", expect: []bool{true, false}},
		{name: "nok, non existent key", givenURL: "/?different=true", expect: []bool(nil), expectErr: ErrNonExistentKey.Error()},
		{name: "nok, invalid value", givenURL: "/?key=true&key=invalidbool", expect: []bool(nil), expectErr: `code=400, message=form values, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(httptest.NewRequest(http.MethodPost, tt.givenURL, nil), nil)

			got, err := FormValues[bool](ctx, "key")

			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestFormValues_UnsupportedType checks that a slice element type is rejected with a BindingError.
func TestFormValues_UnsupportedType(t *testing.T) {
	ctx := NewContext(httptest.NewRequest(http.MethodPost, "/?key=bool", nil), nil)

	got, err := FormValues[[]bool](ctx, "key")

	assert.EqualError(t, err, "code=400, message=form values, err=failed to parse value, err: unsupported value type: *[]bool, field=key")
	assert.Equal(t, [][]bool(nil), got)
}
// TestParseValue_bool checks bool parsing, including an unparsable value and the
// empty-string case (which yields false without an error).
func TestParseValue_bool(t *testing.T) {
	var testCases = []struct {
		name      string
		when      string
		expect    bool
		expectErr string
	}{
		{name: "ok, true", when: "true", expect: true},
		{name: "ok, false", when: "false", expect: false},
		{name: "ok, 1", when: "1", expect: true},
		{name: "ok, 0", when: "0", expect: false},
		{name: "ok, empty value", when: "", expect: false},
		{
			name:      "nok, invalid value",
			when:      "X",
			expect:    false,
			expectErr: `failed to parse value, err: strconv.ParseBool: parsing "X": invalid syntax`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			v, err := ParseValue[bool](tc.when)
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
			} else {
				assert.NoError(t, err)
			}
			assert.Equal(t, tc.expect, v)
		})
	}
}
// TestParseValue_float32 checks float32 parsing, including infinities and NaN.
// Each subtest has a unique name so `go test -run` can target it directly.
func TestParseValue_float32(t *testing.T) {
	var testCases = []struct {
		name      string
		when      string
		expect    float32
		expectErr string
	}{
		{name: "ok, 123.345", when: "123.345", expect: 123.345},
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, +Inf", when: "+Inf", expect: float32(math.Inf(1))},
		{name: "ok, -Inf", when: "-Inf", expect: float32(math.Inf(-1))},
		{name: "ok, NaN", when: "NaN", expect: float32(math.NaN())},
		{
			name:      "nok, invalid value",
			when:      "X",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			v, err := ParseValue[float32](tc.when)
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
			} else {
				assert.NoError(t, err)
			}
			if math.IsNaN(float64(tc.expect)) {
				// NaN != NaN, so assert.Equal cannot be used for this case
				assert.True(t, math.IsNaN(float64(v)), "expected NaN, got %v", v)
				return
			}
			assert.Equal(t, tc.expect, v)
		})
	}
}
// TestParseValue_float64 checks float64 parsing, including infinities and NaN.
// Each subtest has a unique name so `go test -run` can target it directly.
func TestParseValue_float64(t *testing.T) {
	var testCases = []struct {
		name      string
		when      string
		expect    float64
		expectErr string
	}{
		{name: "ok, 123.345", when: "123.345", expect: 123.345},
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, +Inf", when: "+Inf", expect: math.Inf(1)},
		{name: "ok, -Inf", when: "-Inf", expect: math.Inf(-1)},
		{name: "ok, NaN", when: "NaN", expect: math.NaN()},
		{
			name:      "nok, invalid value",
			when:      "X",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			v, err := ParseValue[float64](tc.when)
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
			} else {
				assert.NoError(t, err)
			}
			if math.IsNaN(tc.expect) {
				// NaN != NaN, so assert.Equal cannot be used for this case
				assert.True(t, math.IsNaN(v), "expected NaN, got %v", v)
				return
			}
			assert.Equal(t, tc.expect, v)
		})
	}
}
// TestParseValue_int checks int parsing at the 64-bit boundaries. Failing cases use the
// "nok" prefix, consistent with the other ParseValue tests.
func TestParseValue_int(t *testing.T) {
	var testCases = []struct {
		name      string
		when      string
		expect    int
		expectErr string
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, -1", when: "-1", expect: -1},
		{name: "ok, max int (64bit)", when: "9223372036854775807", expect: 9223372036854775807},
		{name: "ok, min int (64bit)", when: "-9223372036854775808", expect: -9223372036854775808},
		{
			name:      "nok, overflow max int (64bit)",
			when:      "9223372036854775808",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`,
		},
		{
			name:      "nok, underflow min int (64bit)",
			when:      "-9223372036854775809",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`,
		},
		{
			name:      "nok, invalid value",
			when:      "X",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			v, err := ParseValue[int](tc.when)
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
			} else {
				assert.NoError(t, err)
			}
			assert.Equal(t, tc.expect, v)
		})
	}
}
// TestParseValue_uint checks uint parsing at the 64-bit boundary and rejection of negatives.
func TestParseValue_uint(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expect    uint
		expectErr string
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, max uint (64bit)", when: "18446744073709551615", expect: 18446744073709551615},
		{name: "nok, overflow max uint (64bit)", when: "18446744073709551616", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`},
		{name: "nok, negative value", when: "-1", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`},
		{name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[uint](tt.when)

			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_int8 checks int8 parsing at both range boundaries.
func TestParseValue_int8(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expect    int8
		expectErr string
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, -1", when: "-1", expect: -1},
		{name: "ok, max int8", when: "127", expect: 127},
		{name: "ok, min int8", when: "-128", expect: -128},
		{name: "nok, overflow max int8", when: "128", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "128": value out of range`},
		{name: "nok, underflow min int8", when: "-129", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-129": value out of range`},
		{name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[int8](tt.when)

			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_int16 checks int16 parsing at both range boundaries.
func TestParseValue_int16(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expect    int16
		expectErr string
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, -1", when: "-1", expect: -1},
		{name: "ok, max int16", when: "32767", expect: 32767},
		{name: "ok, min int16", when: "-32768", expect: -32768},
		{name: "nok, overflow max int16", when: "32768", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "32768": value out of range`},
		{name: "nok, underflow min int16", when: "-32769", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-32769": value out of range`},
		{name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[int16](tt.when)

			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_int32 checks int32 parsing at both range boundaries.
func TestParseValue_int32(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expect    int32
		expectErr string
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, -1", when: "-1", expect: -1},
		{name: "ok, max int32", when: "2147483647", expect: 2147483647},
		{name: "ok, min int32", when: "-2147483648", expect: -2147483648},
		{name: "nok, overflow max int32", when: "2147483648", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "2147483648": value out of range`},
		{name: "nok, underflow min int32", when: "-2147483649", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-2147483649": value out of range`},
		{name: "nok, invalid value", when: "X", expect: 0, expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[int32](tt.when)

			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_int64 exercises ParseValue[int64] over the int64 range,
// including both range limits, overflow/underflow by one and garbage input.
// Failing cases must yield the zero value.
func TestParseValue_int64(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expectErr string
		expect    int64
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, -1", when: "-1", expect: -1},
		{name: "ok, max int64", when: "9223372036854775807", expect: 9223372036854775807},
		{name: "ok, min int64", when: "-9223372036854775808", expect: -9223372036854775808},
		{
			name:      "nok, overflow max int64",
			when:      "9223372036854775808",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`,
		},
		{
			name:      "nok, underflow min int64",
			when:      "-9223372036854775809",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`,
		},
		{
			name:      "nok, invalid value",
			when:      "X",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[int64](tt.when)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_uint8 exercises ParseValue[uint8]: in-range values, the upper
// limit, overflow by one, a negative number (rejected by strconv.ParseUint as
// invalid syntax) and garbage input. Failing cases must yield zero.
func TestParseValue_uint8(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expectErr string
		expect    uint8
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, max uint8", when: "255", expect: 255},
		{
			name:      "nok, overflow max uint8",
			when:      "256",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "256": value out of range`,
		},
		{
			name:      "nok, negative value",
			when:      "-1",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
		},
		{
			name:      "nok, invalid value",
			when:      "X",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[uint8](tt.when)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_uint16 exercises ParseValue[uint16]: in-range values, the
// upper limit, overflow by one, a negative number and garbage input. Failing
// cases must yield zero.
func TestParseValue_uint16(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expectErr string
		expect    uint16
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, max uint16", when: "65535", expect: 65535},
		{
			name:      "nok, overflow max uint16",
			when:      "65536",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "65536": value out of range`,
		},
		{
			name:      "nok, negative value",
			when:      "-1",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
		},
		{
			name:      "nok, invalid value",
			when:      "X",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[uint16](tt.when)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_uint32 exercises ParseValue[uint32]: in-range values, the
// upper limit, overflow by one, a negative number and garbage input. Failing
// cases must yield zero.
func TestParseValue_uint32(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expectErr string
		expect    uint32
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, max uint32", when: "4294967295", expect: 4294967295},
		{
			name:      "nok, overflow max uint32",
			when:      "4294967296",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "4294967296": value out of range`,
		},
		{
			name:      "nok, negative value",
			when:      "-1",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
		},
		{
			name:      "nok, invalid value",
			when:      "X",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[uint32](tt.when)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_uint64 exercises ParseValue[uint64]: in-range values, the
// upper limit, overflow by one, a negative number and garbage input. Failing
// cases must yield zero.
func TestParseValue_uint64(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expectErr string
		expect    uint64
	}{
		{name: "ok, 0", when: "0", expect: 0},
		{name: "ok, 1", when: "1", expect: 1},
		{name: "ok, max uint64", when: "18446744073709551615", expect: 18446744073709551615},
		{
			name:      "nok, overflow max uint64",
			when:      "18446744073709551616",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`,
		},
		{
			name:      "nok, negative value",
			when:      "-1",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
		},
		{
			name:      "nok, invalid value",
			when:      "X",
			expect:    0,
			expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[uint64](tt.when)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_string checks that ParseValue[string] returns its input
// unchanged, including the empty string, and never reports an error.
func TestParseValue_string(t *testing.T) {
	tests := []struct {
		name   string
		when   string
		expect string
	}{
		{name: "ok, my", when: "my", expect: "my"},
		{name: "ok, empty", when: "", expect: ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[string](tt.when)
			assert.NoError(t, err)
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_Duration verifies ParseValue[time.Duration] for:
//   - a valid duration string;
//   - the empty string, which yields the zero duration without an error;
//   - a string with an unknown unit, which yields the zero duration and an error.
func TestParseValue_Duration(t *testing.T) {
	var testCases = []struct {
		name      string
		when      string
		expect    time.Duration
		expectErr string
	}{
		{
			name:   "ok, 10h11m01s",
			when:   "10h11m01s",
			expect: 10*time.Hour + 11*time.Minute + 1*time.Second,
		},
		{
			name:   "ok, empty",
			when:   "",
			expect: 0,
		},
		{
			// Error case: use the "nok" prefix like every other failing case in this file.
			name:      "nok, invalid",
			when:      "0x0",
			expect:    0,
			expectErr: `failed to parse value, err: time: unknown unit "x" in duration "0x0"`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			v, err := ParseValue[time.Duration](tc.when)
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
			} else {
				assert.NoError(t, err)
			}
			assert.Equal(t, tc.expect, v)
		})
	}
}
// TestParseValue_Time verifies ParseValue[time.Time] with:
//   - the default layout (RFC3339Nano);
//   - an explicit TimeLayout;
//   - TimeOpts (parse in one location, convert to another);
//   - the Unix-epoch layouts (seconds, milliseconds, nanoseconds), including
//     malformed input for each.
//
// Options are passed variadically, exactly as a caller would.
func TestParseValue_Time(t *testing.T) {
	tallinn, err := time.LoadLocation("Europe/Tallinn")
	if err != nil {
		t.Fatal(err)
	}
	berlin, err := time.LoadLocation("Europe/Berlin")
	if err != nil {
		t.Fatal(err)
	}
	// parse and parseInLoc build expected values. A failure here is a bug in the
	// table itself, so abort immediately; t.Helper attributes the failure to
	// the calling line.
	parse := func(t *testing.T, layout string, s string) time.Time {
		t.Helper()
		result, err := time.Parse(layout, s)
		if err != nil {
			t.Fatal(err)
		}
		return result
	}
	parseInLoc := func(t *testing.T, layout string, s string, loc *time.Location) time.Time {
		t.Helper()
		result, err := time.ParseInLocation(layout, s, loc)
		if err != nil {
			t.Fatal(err)
		}
		return result
	}
	var testCases = []struct {
		name         string
		when         string
		whenLayout   TimeLayout
		whenTimeOpts *TimeOpts
		expect       time.Time
		expectErr    string
	}{
		{
			name:   "ok, defaults to RFC3339Nano",
			when:   "2006-01-02T15:04:05.999999999Z",
			expect: parse(t, time.RFC3339Nano, "2006-01-02T15:04:05.999999999Z"),
		},
		{
			// Midnight in Tallinn (UTC+2 in January) is 23:00 of the previous
			// day in Berlin (UTC+1).
			name: "ok, custom TimeOpt",
			when: "2006-01-02",
			whenTimeOpts: &TimeOpts{
				Layout:          time.DateOnly,
				ParseInLocation: tallinn,
				ToInLocation:    berlin,
			},
			expect: parseInLoc(t, time.DateTime, "2006-01-01 23:00:00", berlin),
		},
		{
			name:       "ok, custom layout",
			when:       "2006-01-02",
			whenLayout: TimeLayout(time.DateOnly),
			expect:     parse(t, time.DateOnly, "2006-01-02"),
		},
		{
			name:       "ok, TimeLayoutUnixTime",
			when:       "1766604665",
			whenLayout: TimeLayoutUnixTime,
			expect:     parse(t, time.RFC3339Nano, "2025-12-24T19:31:05Z"),
		},
		{
			name:       "nok, TimeLayoutUnixTime, invalid value",
			when:       "176x6604665",
			whenLayout: TimeLayoutUnixTime,
			expectErr:  `failed to parse value, err: strconv.ParseInt: parsing "176x6604665": invalid syntax`,
		},
		{
			name:       "ok, TimeLayoutUnixTimeMilli",
			when:       "1766604665123",
			whenLayout: TimeLayoutUnixTimeMilli,
			expect:     parse(t, time.RFC3339Nano, "2025-12-24T19:31:05.123Z"),
		},
		{
			name:       "nok, TimeLayoutUnixTimeMilli, invalid value",
			when:       "1x766604665123",
			whenLayout: TimeLayoutUnixTimeMilli,
			expectErr:  `failed to parse value, err: strconv.ParseInt: parsing "1x766604665123": invalid syntax`,
		},
		{
			// Previously mislabelled "...Milli", duplicating the case above; t.Run
			// then silently suffixed it with "#01".
			name:       "ok, TimeLayoutUnixTimeNano",
			when:       "1766604665999999999",
			whenLayout: TimeLayoutUnixTimeNano,
			expect:     parse(t, time.RFC3339Nano, "2025-12-24T19:31:05.999999999Z"),
		},
		{
			name:       "nok, TimeLayoutUnixTimeNano, invalid value",
			when:       "1x766604665999999999",
			whenLayout: TimeLayoutUnixTimeNano,
			expectErr:  `failed to parse value, err: strconv.ParseInt: parsing "1x766604665999999999": invalid syntax`,
		},
		{
			name:      "nok, invalid",
			when:      "xx",
			expect:    time.Time{},
			expectErr: `failed to parse value, err: parsing time "xx" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "xx" as "2006"`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var opts []any
			if tc.whenLayout != "" {
				opts = append(opts, tc.whenLayout)
			}
			if tc.whenTimeOpts != nil {
				opts = append(opts, *tc.whenTimeOpts)
			}
			v, err := ParseValue[time.Time](tc.when, opts...)
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
			} else {
				assert.NoError(t, err)
			}
			assert.Equal(t, tc.expect, v)
		})
	}
}
// TestParseValue_OptionsOnlyForTime checks that passing a layout option when the
// target type is not time.Time is rejected with an error naming the target type.
func TestParseValue_OptionsOnlyForTime(t *testing.T) {
	const wantErr = `failed to parse value, err: options are only supported for time.Time, got *int`

	_, err := ParseValue[int]("test", TimeLayoutUnixTime)

	assert.EqualError(t, err, wantErr)
}
// TestParseValue_BindUnmarshaler checks that ParseValue delegates to the target
// type's own unmarshalling, here using the Timestamp test type, and that its
// error is returned wrapped with the "failed to parse value" prefix.
func TestParseValue_BindUnmarshaler(t *testing.T) {
	// The expected value is a fixture; fail loudly rather than silently compare
	// against the zero time if it ever stops parsing.
	exampleTime, err := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
	if err != nil {
		t.Fatal(err)
	}
	var testCases = []struct {
		name      string
		when      string
		expect    Timestamp
		expectErr string
	}{
		{
			name:   "ok",
			when:   "2020-12-23T09:45:31+02:00",
			expect: Timestamp(exampleTime),
		},
		{
			name:      "nok, invalid value",
			when:      "2020-12-23T09:45:3102:00",
			expect:    Timestamp{},
			expectErr: `failed to parse value, err: parsing time "2020-12-23T09:45:3102:00" as "2006-01-02T15:04:05Z07:00": cannot parse "02:00" as "Z07:00"`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			v, err := ParseValue[Timestamp](tc.when)
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
			} else {
				assert.NoError(t, err)
			}
			assert.Equal(t, tc.expect, v)
		})
	}
}
// TestParseValue_TextUnmarshaler checks ParseValue against the test type
// TextUnmarshalerType. The expectations show that it upper-cases its input,
// accepts the empty string and rejects the literal "invalid"; ParseValue
// wraps that rejection with the "failed to parse value" prefix.
func TestParseValue_TextUnmarshaler(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expectErr string
		expect    TextUnmarshalerType
	}{
		{name: "ok, converts to uppercase", when: "hello", expect: TextUnmarshalerType{Value: "HELLO"}},
		{name: "ok, empty string", when: "", expect: TextUnmarshalerType{Value: ""}},
		{
			name:      "nok, invalid value",
			when:      "invalid",
			expect:    TextUnmarshalerType{},
			expectErr: "failed to parse value, err: invalid value: invalid",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[TextUnmarshalerType](tt.when)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValue_JSONUnmarshaler checks ParseValue against JSONUnmarshalerType.
// The input must be valid JSON: a quoted string works, while bare words
// surface the JSON decoder's syntax error wrapped with the usual prefix.
func TestParseValue_JSONUnmarshaler(t *testing.T) {
	tests := []struct {
		name      string
		when      string
		expectErr string
		expect    JSONUnmarshalerType
	}{
		{name: "ok, valid JSON string", when: `"hello"`, expect: JSONUnmarshalerType{Value: "hello"}},
		{name: "ok, empty JSON string", when: `""`, expect: JSONUnmarshalerType{Value: ""}},
		{
			name:      "nok, invalid JSON",
			when:      "not-json",
			expect:    JSONUnmarshalerType{},
			expectErr: "failed to parse value, err: invalid character 'o' in literal null (expecting 'u')",
		},
		{
			name:      "nok, unquoted string",
			when:      "hello",
			expect:    JSONUnmarshalerType{},
			expectErr: "failed to parse value, err: invalid character 'h' looking for beginning of value",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValue[JSONUnmarshalerType](tt.when)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestParseValues_bools checks ParseValues[bool] on a slice of inputs. All
// values are converted in order when valid; a single bad element makes the
// whole call return a nil slice and that element's error.
func TestParseValues_bools(t *testing.T) {
	tests := []struct {
		name      string
		expectErr string
		when      []string
		expect    []bool
	}{
		{
			name:   "ok",
			when:   []string{"true", "0", "false", "1"},
			expect: []bool{true, false, false, true},
		},
		{
			name:      "nok",
			when:      []string{"true", "10"},
			expect:    nil,
			expectErr: `failed to parse value, err: strconv.ParseBool: parsing "10": invalid syntax`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := ParseValues[bool](tt.when)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestPathParamOr checks PathParamOr[int] on the "id" path parameter:
//   - a present value is parsed;
//   - a missing or empty parameter yields the supplied default;
//   - an unparsable value yields a 400 error and the zero value.
func TestPathParamOr(t *testing.T) {
	tests := []struct {
		name         string
		givenKey     string
		givenValue   string
		expectErr    string
		defaultValue int
		expect       int
	}{
		{name: "ok, param exists", givenKey: "id", givenValue: "123", defaultValue: 999, expect: 123},
		{name: "ok, param missing - returns default", givenKey: "other", givenValue: "123", defaultValue: 999, expect: 999},
		{name: "ok, param exists but empty - returns default", givenKey: "id", givenValue: "", defaultValue: 999, expect: 999},
		{
			name:         "nok, invalid value",
			givenKey:     "id",
			givenValue:   "invalid",
			defaultValue: 999,
			expectErr:    "code=400, message=path value, err=failed to parse value",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(nil, nil)
			ctx.SetPathValues(PathValues{{Name: tt.givenKey, Value: tt.givenValue}})

			got, err := PathParamOr[int](ctx, "id", tt.defaultValue)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.ErrorContains(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestQueryParamOr checks QueryParamOr[int] on the "key" query parameter:
//   - a present value is parsed;
//   - a missing or empty parameter falls back to the default;
//   - an unparsable value yields a 400 error and the zero value.
func TestQueryParamOr(t *testing.T) {
	tests := []struct {
		name         string
		givenURL     string
		expectErr    string
		defaultValue int
		expect       int
	}{
		{name: "ok, param exists", givenURL: "/?key=42", defaultValue: 999, expect: 42},
		{name: "ok, param missing - returns default", givenURL: "/?other=42", defaultValue: 999, expect: 999},
		{name: "ok, param exists but empty - returns default", givenURL: "/?key=", defaultValue: 999, expect: 999},
		{name: "nok, invalid value", givenURL: "/?key=invalid", defaultValue: 999, expectErr: "code=400, message=query param"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(httptest.NewRequest(http.MethodGet, tt.givenURL, nil), nil)

			got, err := QueryParamOr[int](ctx, "key", tt.defaultValue)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.ErrorContains(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestQueryParamsOr checks QueryParamsOr[int] on the repeated "key" query
// parameter:
//   - all values are parsed in order;
//   - a missing parameter yields the default slice;
//   - any unparsable value yields a 400 error and a nil slice.
func TestQueryParamsOr(t *testing.T) {
	tests := []struct {
		name         string
		givenURL     string
		expectErr    string
		defaultValue []int
		expect       []int
	}{
		{name: "ok, params exist", givenURL: "/?key=1&key=2&key=3", defaultValue: []int{999}, expect: []int{1, 2, 3}},
		{name: "ok, params missing - returns default", givenURL: "/?other=1", defaultValue: []int{7, 8, 9}, expect: []int{7, 8, 9}},
		{name: "nok, invalid value", givenURL: "/?key=1&key=invalid", defaultValue: []int{999}, expectErr: "code=400, message=query params"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(httptest.NewRequest(http.MethodGet, tt.givenURL, nil), nil)

			got, err := QueryParamsOr[int](ctx, "key", tt.defaultValue)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.ErrorContains(t, err, tt.expectErr)
			}
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestFormValueOr checks FormValueOr[string] for the "name" field on a POST
// request whose values come only from the URL query string. A present value is
// returned; a missing or empty one falls back to the default. No case is
// expected to error.
func TestFormValueOr(t *testing.T) {
	tests := []struct {
		name         string
		givenURL     string
		defaultValue string
		expect       string
	}{
		{name: "ok, value exists", givenURL: "/?name=john", defaultValue: "default", expect: "john"},
		{name: "ok, value missing - returns default", givenURL: "/?other=john", defaultValue: "default", expect: "default"},
		{name: "ok, value exists but empty - returns default", givenURL: "/?name=", defaultValue: "default", expect: "default"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(httptest.NewRequest(http.MethodPost, tt.givenURL, nil), nil)

			got, err := FormValueOr[string](ctx, "name", tt.defaultValue)
			assert.NoError(t, err)
			assert.Equal(t, tt.expect, got)
		})
	}
}
// TestFormValuesOr checks FormValuesOr[string] for the repeated "tags" field on
// a POST request whose values come only from the URL query string. All values
// are returned in order when present; the default slice otherwise. No case is
// expected to error.
func TestFormValuesOr(t *testing.T) {
	tests := []struct {
		name         string
		givenURL     string
		defaultValue []string
		expect       []string
	}{
		{
			name:         "ok, values exist",
			givenURL:     "/?tags=go&tags=rust&tags=python",
			defaultValue: []string{"default"},
			expect:       []string{"go", "rust", "python"},
		},
		{
			name:         "ok, values missing - returns default",
			givenURL:     "/?other=value",
			defaultValue: []string{"a", "b"},
			expect:       []string{"a", "b"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := NewContext(httptest.NewRequest(http.MethodPost, tt.givenURL, nil), nil)

			got, err := FormValuesOr[string](ctx, "tags", tt.defaultValue)
			assert.NoError(t, err)
			assert.Equal(t, tt.expect, got)
		})
	}
}
================================================
FILE: binder_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"encoding/json"
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"io"
"math/big"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
)
// createTestContext returns a *Context for a GET request to URL, created by a
// fresh Echo instance with an httptest recorder as the response writer.
//
// A non-nil body marks the request's Content-Type as JSON. Entries of
// pathValues become the context's path values. They are collected by ranging
// over the map, so their order inside PathValues is not deterministic.
func createTestContext(URL string, body io.Reader, pathValues map[string]string) *Context {
	e := New()

	req := httptest.NewRequest(http.MethodGet, URL, body)
	if body != nil {
		req.Header.Set(HeaderContentType, MIMEApplicationJSON)
	}
	c := e.NewContext(req, httptest.NewRecorder())

	if len(pathValues) == 0 {
		return c
	}
	params := make(PathValues, 0, len(pathValues))
	for name, value := range pathValues {
		params = append(params, PathValue{Name: name, Value: value})
	}
	c.SetPathValues(params)
	return c
}
// TestBindingError_Error checks the string form of a BindingError built by
// NewBindingError, and that the constructor fills every field:
//   - Code is 400;
//   - Message, Field and Values are copied from the arguments;
//   - the wrapped internal error is kept.
func TestBindingError_Error(t *testing.T) {
	err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error"))
	assert.EqualError(t, err, `code=400, message=bind failed, err=internal error, field=id`)

	// Checked assertion: a wrong dynamic type should fail the test with a clear
	// message instead of panicking.
	bErr, ok := err.(*BindingError)
	if !ok {
		t.Fatalf("expected *BindingError, got %T", err)
	}
	assert.Equal(t, 400, bErr.Code)
	assert.Equal(t, "bind failed", bErr.Message)
	assert.Equal(t, errors.New("internal error"), bErr.err)
	assert.Equal(t, "id", bErr.Field)
	assert.Equal(t, []string{"1", "nope"}, bErr.Values)
}
// TestBindingError_ErrorJSON checks the JSON encoding of a BindingError. Per the
// expected output, only the field name and message are serialized; the status
// code, raw values and wrapped internal error are not.
func TestBindingError_ErrorJSON(t *testing.T) {
	bindErr := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error"))

	resp, err := json.Marshal(bindErr)
	// Previously ignored: a marshal failure would only surface as a confusing
	// mismatch against an empty string.
	assert.NoError(t, err)
	assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp))
}
func TestPathValuesBinder(t *testing.T) {
c := createTestContext("/api/user/999", nil, map[string]string{
"id": "1",
"nr": "2",
"slice": "3",
})
b := PathValuesBinder(c)
id := int64(99)
nr := int64(88)
var slice = make([]int64, 0)
var notExisting = make([]int64, 0)
err := b.Int64("id", &id).
Int64("nr", &nr).
Int64s("slice", &slice).
Int64s("not_existing", ¬Existing).
BindError()
assert.NoError(t, err)
assert.Equal(t, int64(1), id)
assert.Equal(t, int64(2), nr)
assert.Equal(t, []int64{3}, slice) // binding params to slice does not make sense but it should not panic either
assert.Equal(t, []int64{}, notExisting) // binding params to slice does not make sense but it should not panic either
}
// TestQueryParamsBinder_FailFast contrasts the two error-collection modes of
// QueryParamsBinder when both bound query parameters are invalid. With
// FailFast(true) only the first failure (field id) is reported; with
// FailFast(false) both are collected.
func TestQueryParamsBinder_FailFast(t *testing.T) {
	const (
		errID = `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`
		errNr = `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "en": invalid syntax, field=nr`
	)
	tests := []struct {
		name          string
		whenURL       string
		expectError   []string
		givenFailFast bool
	}{
		{
			name:          "ok, FailFast=true stops at first error",
			whenURL:       "/api/user/999?nr=en&id=nope",
			givenFailFast: true,
			expectError:   []string{errID},
		},
		{
			name:          "ok, FailFast=false encounters all errors",
			whenURL:       "/api/user/999?nr=en&id=nope",
			givenFailFast: false,
			expectError:   []string{errID, errNr},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := createTestContext(tt.whenURL, nil, map[string]string{"id": "999"})
			binder := QueryParamsBinder(ctx).FailFast(tt.givenFailFast)

			id, nr := int64(99), int64(88)
			errs := binder.Int64("id", &id).Int64("nr", &nr).BindErrors()

			assert.Len(t, errs, len(tt.expectError))
			for _, e := range errs {
				assert.Contains(t, tt.expectError, e.Error())
			}
		})
	}
}
func TestFormFieldBinder(t *testing.T) {
e := New()
body := `texta=foo&slice=5`
req := httptest.NewRequest(http.MethodPost, "/api/search?id=1&nr=2&slice=3&slice=4", strings.NewReader(body))
req.Header.Set(HeaderContentLength, strconv.Itoa(len(body)))
req.Header.Set(HeaderContentType, MIMEApplicationForm)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
b := FormFieldBinder(c)
var texta string
id := int64(99)
nr := int64(88)
var slice = make([]int64, 0)
var notExisting = make([]int64, 0)
err := b.
Int64s("slice", &slice).
Int64("id", &id).
Int64("nr", &nr).
String("texta", &texta).
Int64s("notExisting", ¬Existing).
BindError()
assert.NoError(t, err)
assert.Equal(t, "foo", texta)
assert.Equal(t, int64(1), id)
assert.Equal(t, int64(2), nr)
assert.Equal(t, []int64{5, 3, 4}, slice)
assert.Equal(t, []int64{}, notExisting)
}
func TestValueBinder_errorStopsBinding(t *testing.T) {
// this test documents "feature" that binding multiple params can change destination if it was bound before
// failing parameter binding
c := createTestContext("/api/user/999?id=1&nr=nope", nil, nil)
b := QueryParamsBinder(c)
id := int64(99) // will be changed before nr binding fails
nr := int64(88) // will not be changed
err := b.Int64("id", &id).
Int64("nr", &nr).
BindError()
assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr")
assert.Equal(t, int64(1), id)
assert.Equal(t, int64(88), nr)
}
func TestValueBinder_BindError(t *testing.T) {
c := createTestContext("/api/user/999?nr=en&id=nope", nil, nil)
b := QueryParamsBinder(c)
id := int64(99)
nr := int64(88)
err := b.Int64("id", &id).
Int64("nr", &nr).
BindError()
assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id")
assert.Nil(t, b.errors)
assert.Nil(t, b.BindError())
}
func TestValueBinder_GetValues(t *testing.T) {
var testCases = []struct {
whenValuesFunc func(sourceParam string) []string
name string
expectError string
expect []int64
}{
{
name: "ok, default implementation",
expect: []int64{1, 101},
},
{
name: "ok, values returns nil",
whenValuesFunc: func(sourceParam string) []string {
return nil
},
expect: []int64(nil),
},
{
name: "ok, values returns empty slice",
whenValuesFunc: func(sourceParam string) []string {
return []string{}
},
expect: []int64(nil),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext("/search?nr=en&id=1&id=101", nil, nil)
b := QueryParamsBinder(c)
if tc.whenValuesFunc != nil {
b.ValuesFunc = tc.whenValuesFunc
}
var IDs []int64
err := b.Int64s("id", &IDs).BindError()
assert.Equal(t, tc.expect, IDs)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValueBinder_CustomFuncWithError(t *testing.T) {
c := createTestContext("/search?nr=en&id=1&id=101", nil, nil)
b := QueryParamsBinder(c)
id := int64(99)
givenCustomFunc := func(values []string) []error {
assert.Equal(t, []string{"1", "101"}, values)
return []error{
errors.New("first error"),
errors.New("second error"),
}
}
err := b.CustomFunc("id", givenCustomFunc).BindError()
assert.Equal(t, int64(99), id)
assert.EqualError(t, err, "first error")
}
func TestValueBinder_CustomFunc(t *testing.T) {
var testCases = []struct {
expectValue any
name string
whenURL string
givenFuncErrors []error
expectParamValues []string
expectErrors []string
givenFailFast bool
}{
{
name: "ok, binds value",
whenURL: "/search?nr=en&id=1&id=100",
expectParamValues: []string{"1", "100"},
expectValue: int64(1000),
},
{
name: "ok, params values empty, value is not changed",
whenURL: "/search?nr=en",
expectParamValues: []string{},
expectValue: int64(99),
},
{
name: "nok, previous errors fail fast without binding value",
givenFailFast: true,
whenURL: "/search?nr=en&id=1&id=100",
expectParamValues: []string{"1", "100"},
expectValue: int64(99),
expectErrors: []string{"previous error"},
},
{
name: "nok, func returns errors",
givenFuncErrors: []error{
errors.New("first error"),
errors.New("second error"),
},
whenURL: "/search?nr=en&id=1&id=100",
expectParamValues: []string{"1", "100"},
expectValue: int64(99),
expectErrors: []string{"first error", "second error"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
if tc.givenFailFast {
b.errors = []error{errors.New("previous error")}
}
id := int64(99)
givenCustomFunc := func(values []string) []error {
assert.Equal(t, tc.expectParamValues, values)
if tc.givenFuncErrors == nil {
id = 1000 // emulated conversion and setting value
return nil
}
return tc.givenFuncErrors
}
errs := b.CustomFunc("id", givenCustomFunc).BindErrors()
assert.Equal(t, tc.expectValue, id)
if tc.expectErrors != nil {
assert.Len(t, errs, len(tc.expectErrors))
for _, err := range errs {
assert.Contains(t, tc.expectErrors, err.Error())
}
} else {
assert.Nil(t, errs)
}
})
}
}
func TestValueBinder_MustCustomFunc(t *testing.T) {
var testCases = []struct {
expectValue any
name string
whenURL string
givenFuncErrors []error
expectParamValues []string
expectErrors []string
givenFailFast bool
}{
{
name: "ok, binds value",
whenURL: "/search?nr=en&id=1&id=100",
expectParamValues: []string{"1", "100"},
expectValue: int64(1000),
},
{
name: "nok, params values empty, returns error, value is not changed",
whenURL: "/search?nr=en",
expectParamValues: []string{},
expectValue: int64(99),
expectErrors: []string{"code=400, message=required field value is empty, field=id"},
},
{
name: "nok, previous errors fail fast without binding value",
givenFailFast: true,
whenURL: "/search?nr=en&id=1&id=100",
expectParamValues: []string{"1", "100"},
expectValue: int64(99),
expectErrors: []string{"previous error"},
},
{
name: "nok, func returns errors",
givenFuncErrors: []error{
errors.New("first error"),
errors.New("second error"),
},
whenURL: "/search?nr=en&id=1&id=100",
expectParamValues: []string{"1", "100"},
expectValue: int64(99),
expectErrors: []string{"first error", "second error"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := createTestContext(tc.whenURL, nil, nil)
b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
if tc.givenFailFast {
b.errors = []error{errors.New("previous error")}
}
id := int64(99)
givenCustomFunc := func(values []string) []error {
assert.Equal(t, tc.expectParamValues, values)
if tc.givenFuncErrors == nil {
id = 1000 // emulated conversion and setting value
return nil
}
return tc.givenFuncErrors
}
errs := b.MustCustomFunc("id", givenCustomFunc).BindErrors()
assert.Equal(t, tc.expectValue, id)
if tc.expectErrors != nil {
assert.Len(t, errs, len(tc.expectErrors))
for _, err := range errs {
assert.Contains(t, tc.expectErrors, err.Error())
}
} else {
assert.Nil(t, errs)
}
})
}
}
// TestValueBinder_String checks String and MustString binding of the "param"
// query parameter:
//   - the first of several values is bound;
//   - a missing key leaves the destination unchanged, and is an error only
//     for the Must variant;
//   - an earlier error in fail-fast mode prevents binding.
//
// Fixed: the query strings were mis-encoded as "param=en¶m=de" (an HTML
// "&para;" artefact). That parses as a single value "en¶m=de", so the
// binding expectations could never hold.
func TestValueBinder_String(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectValue     string
		expectError     string
		givenBindErrors []error
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=en&param=de",
			expectValue: "en",
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nr=en",
			expectValue: "default",
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?nr=en&id=1&id=100",
			expectValue:   "default",
			expectError:   "previous error",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=en&param=de",
			expectValue: "en",
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nr=en",
			expectValue: "default",
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?nr=en&id=1&id=100",
			expectValue:   "default",
			expectError:   "previous error",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier failure so fail-fast mode skips this binding.
				b.errors = []error{errors.New("previous error")}
			}
			dest := "default"
			var err error
			if tc.whenMust {
				err = b.MustString("param", &dest).BindError()
			} else {
				err = b.String("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Strings checks Strings and MustStrings binding of the
// repeated "param" query parameter:
//   - all values are bound in order;
//   - a missing key leaves the destination unchanged, and is an error only
//     for the Must variant;
//   - an earlier error in fail-fast mode prevents binding.
//
// Fixed: the query strings were mis-encoded as "param=en¶m=de" (an HTML
// "&para;" artefact). That parses as a single value, so the two-element
// expectations could never hold.
func TestValueBinder_Strings(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     []string
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=en&param=de",
			expectValue: []string{"en", "de"},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nr=en",
			expectValue: []string{"default"},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?nr=en&id=1&id=100",
			expectValue:   []string{"default"},
			expectError:   "previous error",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=en&param=de",
			expectValue: []string{"en", "de"},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nr=en",
			expectValue: []string{"default"},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?nr=en&id=1&id=100",
			expectValue:   []string{"default"},
			expectError:   "previous error",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier failure so fail-fast mode skips this binding.
				b.errors = []error{errors.New("previous error")}
			}
			dest := []string{"default"}
			var err error
			if tc.whenMust {
				err = b.MustStrings("param", &dest).BindError()
			} else {
				err = b.Strings("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Int64_intValue checks Int64 and MustInt64 binding of the
// "param" query parameter:
//   - the first value is parsed and bound;
//   - a missing key leaves the destination unchanged, and is an error only
//     for the Must variant;
//   - an earlier error in fail-fast mode prevents binding;
//   - an unparsable first value yields a 400 conversion error without
//     touching the destination.
//
// Fixed: the query strings were mis-encoded as "param=1¶m=100" (an HTML
// "&para;" artefact). That parses as a single unparsable value, so neither
// the success cases nor the exact conversion-error message could match.
func TestValueBinder_Int64_intValue(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     int64
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=1&param=100",
			expectValue: 1,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: 99,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   99,
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: 99,
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=100",
			expectValue: 1,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: 99,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   99,
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: 99,
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier failure so fail-fast mode skips this binding.
				b.errors = []error{errors.New("previous error")}
			}
			dest := int64(99)
			var err error
			if tc.whenMust {
				err = b.MustInt64("param", &dest).BindError()
			} else {
				err = b.Int64("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Int_errorMessage covers the platform-sized Int and Uint binders.
// Their error messages name the target type as plain "int"/"uint" instead of a sized
// type, so they are asserted separately from the sized-integer tests.
func TestValueBinder_Int_errorMessage(t *testing.T) {
	ec := createTestContext("/search?param=nope", nil, nil)

	intDest := 99
	uintDest := uint(98)
	bindErrs := QueryParamsBinder(ec).
		FailFast(false).
		Int("param", &intDest).
		Uint("param", &uintDest).
		BindErrors()

	// A failed conversion must leave both destinations untouched.
	assert.Equal(t, 99, intDest)
	assert.Equal(t, uint(98), uintDest)

	// Fail-fast is off, so both binders report their own error in call order.
	assert.EqualError(t, bindErrs[0], `code=400, message=failed to bind field value to int, err=strconv.ParseInt: parsing "nope": invalid syntax, field=param`)
	assert.EqualError(t, bindErrs[1], `code=400, message=failed to bind field value to uint, err=strconv.ParseUint: parsing "nope": invalid syntax, field=param`)
}
// TestValueBinder_Uint64_uintValue checks Uint64 and MustUint64 binding of the "param"
// query parameter. The cases cover:
//   - successful binding, where the first value wins;
//   - a missing parameter, which is an error only for the Must variant;
//   - fail-fast short-circuiting on a previously recorded error;
//   - conversion failures.
//
// In every case that does not bind, the destination must keep its initial value of 99.
func TestValueBinder_Uint64_uintValue(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   uint64
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=1&param=100",
			expectValue: 1,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: 99,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   99,
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: 99,
			expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=100",
			expectValue: 1,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: 99,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   99,
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: 99,
			expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Simulate an error recorded by an earlier binding call. With fail-fast
				// enabled, the binder must return it without touching dest.
				b.errors = []error{errors.New("previous error")}
			}
			dest := uint64(99)
			var err error
			if tc.whenMust {
				err = b.MustUint64("param", &dest).BindError()
			} else {
				err = b.Uint64("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Int_Types binds a distinct query value into every scalar integer
// destination supported by ValueBinder, in both plain and Must variants, using one
// chained binder. It then checks that each field received exactly its own value.
func TestValueBinder_Int_Types(t *testing.T) {
	type target struct {
		int64      int64
		mustInt64  int64
		uint64     uint64
		mustUint64 uint64
		int32      int32
		mustInt32  int32
		uint32     uint32
		mustUint32 uint32
		int16      int16
		mustInt16  int16
		uint16     uint16
		mustUint16 uint16
		int8       int8
		mustInt8   int8
		uint8      uint8
		mustUint8  uint8
		byte       byte
		mustByte   byte
		int        int
		mustInt    int
		uint       uint
		mustUint   uint
	}

	// One key=value pair per destination field; every value is unique so that a
	// binder writing into the wrong field would be detected.
	pairs := []string{
		"int64=1",
		"mustInt64=2",
		"uint64=3",
		"mustUint64=4",
		"int32=5",
		"mustInt32=6",
		"uint32=7",
		"mustUint32=8",
		"int16=9",
		"mustInt16=10",
		"uint16=11",
		"mustUint16=12",
		"int8=13",
		"mustInt8=14",
		"uint8=15",
		"mustUint8=16",
		"byte=17",
		"mustByte=18",
		"int=19",
		"mustInt=20",
		"uint=21",
		"mustUint=22",
	}
	ec := createTestContext("/search?"+strings.Join(pairs, "&"), nil, nil)

	var got target
	bindErr := QueryParamsBinder(ec).
		Int64("int64", &got.int64).
		MustInt64("mustInt64", &got.mustInt64).
		Uint64("uint64", &got.uint64).
		MustUint64("mustUint64", &got.mustUint64).
		Int32("int32", &got.int32).
		MustInt32("mustInt32", &got.mustInt32).
		Uint32("uint32", &got.uint32).
		MustUint32("mustUint32", &got.mustUint32).
		Int16("int16", &got.int16).
		MustInt16("mustInt16", &got.mustInt16).
		Uint16("uint16", &got.uint16).
		MustUint16("mustUint16", &got.mustUint16).
		Int8("int8", &got.int8).
		MustInt8("mustInt8", &got.mustInt8).
		Uint8("uint8", &got.uint8).
		MustUint8("mustUint8", &got.mustUint8).
		Byte("byte", &got.byte).
		MustByte("mustByte", &got.mustByte).
		Int("int", &got.int).
		MustInt("mustInt", &got.mustInt).
		Uint("uint", &got.uint).
		MustUint("mustUint", &got.mustUint).
		BindError()
	assert.NoError(t, bindErr)

	want := target{
		int64:      1,
		mustInt64:  2,
		uint64:     3,
		mustUint64: 4,
		int32:      5,
		mustInt32:  6,
		uint32:     7,
		mustUint32: 8,
		int16:      9,
		mustInt16:  10,
		uint16:     11,
		mustUint16: 12,
		int8:       13,
		mustInt8:   14,
		uint8:      15,
		mustUint8:  16,
		byte:       17,
		mustByte:   18,
		int:        19,
		mustInt:    20,
		uint:       21,
		mustUint:   22,
	}
	assert.Equal(t, want, got)
}
// TestValueBinder_Int64s_intsValue checks Int64s and MustInt64s binding of every
// "param" query value into a slice, in order, including duplicates. The cases also
// cover a missing parameter, fail-fast on a previously recorded error, and conversion
// failures. In each of those, the pre-filled slice {99} must stay untouched.
func TestValueBinder_Int64s_intsValue(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   []int64
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=1&param=2&param=1",
			expectValue: []int64{1, 2, 1},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []int64{99},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []int64{99},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []int64{99},
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=2&param=1",
			expectValue: []int64{1, 2, 1},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []int64{99},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []int64{99},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []int64{99},
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Simulate an error recorded by an earlier binding call.
				b.errors = []error{errors.New("previous error")}
			}
			dest := []int64{99} // when values are set with bind - contents before bind is gone
			var err error
			if tc.whenMust {
				err = b.MustInt64s("param", &dest).BindError()
			} else {
				err = b.Int64s("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Uint64s_uintsValue checks Uint64s and MustUint64s binding of every
// "param" query value into a slice, in order, including duplicates. The cases also
// cover a missing parameter, fail-fast on a previously recorded error, and conversion
// failures. In each of those, the pre-filled slice {99} must stay untouched.
func TestValueBinder_Uint64s_uintsValue(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   []uint64
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=1&param=2&param=1",
			expectValue: []uint64{1, 2, 1},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []uint64{99},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []uint64{99},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []uint64{99},
			expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=2&param=1",
			expectValue: []uint64{1, 2, 1},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []uint64{99},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []uint64{99},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []uint64{99},
			expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Simulate an error recorded by an earlier binding call.
				b.errors = []error{errors.New("previous error")}
			}
			dest := []uint64{99} // when values are set with bind - contents before bind is gone
			var err error
			if tc.whenMust {
				err = b.MustUint64s("param", &dest).BindError()
			} else {
				err = b.Uint64s("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Ints_Types sends every query parameter twice and binds each one into
// the matching integer-slice destination, in both plain and Must variants. It checks
// that every slice collects both copies of its own value.
func TestValueBinder_Ints_Types(t *testing.T) {
	type target struct {
		int64      []int64
		mustInt64  []int64
		uint64     []uint64
		mustUint64 []uint64
		int32      []int32
		mustInt32  []int32
		uint32     []uint32
		mustUint32 []uint32
		int16      []int16
		mustInt16  []int16
		uint16     []uint16
		mustUint16 []uint16
		int8       []int8
		mustInt8   []int8
		uint8      []uint8
		mustUint8  []uint8
		int        []int
		mustInt    []int
		uint       []uint
		mustUint   []uint
	}
	pairs := []string{
		"int64=1",
		"mustInt64=2",
		"uint64=3",
		"mustUint64=4",
		"int32=5",
		"mustInt32=6",
		"uint32=7",
		"mustUint32=8",
		"int16=9",
		"mustInt16=10",
		"uint16=11",
		"mustUint16=12",
		"int8=13",
		"mustInt8=14",
		"uint8=15",
		"mustUint8=16",
		"int=19",
		"mustInt=20",
		"uint=21",
		"mustUint=22",
	}

	// Duplicate every pair so each slice binder must gather two values.
	doubled := make([]string, 0, 2*len(pairs))
	for _, kv := range pairs {
		doubled = append(doubled, kv, kv)
	}
	ec := createTestContext("/search?&"+strings.Join(doubled, "&"), nil, nil)

	var got target
	bindErr := QueryParamsBinder(ec).
		Int64s("int64", &got.int64).
		MustInt64s("mustInt64", &got.mustInt64).
		Uint64s("uint64", &got.uint64).
		MustUint64s("mustUint64", &got.mustUint64).
		Int32s("int32", &got.int32).
		MustInt32s("mustInt32", &got.mustInt32).
		Uint32s("uint32", &got.uint32).
		MustUint32s("mustUint32", &got.mustUint32).
		Int16s("int16", &got.int16).
		MustInt16s("mustInt16", &got.mustInt16).
		Uint16s("uint16", &got.uint16).
		MustUint16s("mustUint16", &got.mustUint16).
		Int8s("int8", &got.int8).
		MustInt8s("mustInt8", &got.mustInt8).
		Uint8s("uint8", &got.uint8).
		MustUint8s("mustUint8", &got.mustUint8).
		Ints("int", &got.int).
		MustInts("mustInt", &got.mustInt).
		Uints("uint", &got.uint).
		MustUints("mustUint", &got.mustUint).
		BindError()
	assert.NoError(t, bindErr)

	want := target{
		int64:      []int64{1, 1},
		mustInt64:  []int64{2, 2},
		uint64:     []uint64{3, 3},
		mustUint64: []uint64{4, 4},
		int32:      []int32{5, 5},
		mustInt32:  []int32{6, 6},
		uint32:     []uint32{7, 7},
		mustUint32: []uint32{8, 8},
		int16:      []int16{9, 9},
		mustInt16:  []int16{10, 10},
		uint16:     []uint16{11, 11},
		mustUint16: []uint16{12, 12},
		int8:       []int8{13, 13},
		mustInt8:   []int8{14, 14},
		uint8:      []uint8{15, 15},
		mustUint8:  []uint8{16, 16},
		int:        []int{19, 19},
		mustInt:    []int{20, 20},
		uint:       []uint{21, 21},
		mustUint:   []uint{22, 22},
	}
	assert.Equal(t, want, got)
}
// TestValueBinder_Ints_Types_FailFast checks every integer slice binder with fail-fast
// enabled against the values "1", "nope", "2". Binding must stop at the unparsable
// "nope", leave the destination slice nil (the already-parsed "1" is not kept), and
// report an error naming the offending value and the target type.
func TestValueBinder_Ints_Types_FailFast(t *testing.T) {
	// FailFast() should stop parsing and return early
	errTmpl := "code=400, message=failed to bind field value to %v, err=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param"
	c := createTestContext("/search?param=1&param=nope&param=2", nil, nil)
	var dest64 []int64
	err := QueryParamsBinder(c).FailFast(true).Int64s("param", &dest64).BindError()
	assert.Equal(t, []int64(nil), dest64)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int64", "Int"))

	var dest32 []int32
	err = QueryParamsBinder(c).FailFast(true).Int32s("param", &dest32).BindError()
	assert.Equal(t, []int32(nil), dest32)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int32", "Int"))

	var dest16 []int16
	err = QueryParamsBinder(c).FailFast(true).Int16s("param", &dest16).BindError()
	assert.Equal(t, []int16(nil), dest16)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int16", "Int"))

	var dest8 []int8
	err = QueryParamsBinder(c).FailFast(true).Int8s("param", &dest8).BindError()
	assert.Equal(t, []int8(nil), dest8)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int8", "Int"))

	var dest []int
	err = QueryParamsBinder(c).FailFast(true).Ints("param", &dest).BindError()
	assert.Equal(t, []int(nil), dest)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int", "Int"))

	var destu64 []uint64
	err = QueryParamsBinder(c).FailFast(true).Uint64s("param", &destu64).BindError()
	assert.Equal(t, []uint64(nil), destu64)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint64", "Uint"))

	var destu32 []uint32
	err = QueryParamsBinder(c).FailFast(true).Uint32s("param", &destu32).BindError()
	assert.Equal(t, []uint32(nil), destu32)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint32", "Uint"))

	var destu16 []uint16
	err = QueryParamsBinder(c).FailFast(true).Uint16s("param", &destu16).BindError()
	assert.Equal(t, []uint16(nil), destu16)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint16", "Uint"))

	var destu8 []uint8
	err = QueryParamsBinder(c).FailFast(true).Uint8s("param", &destu8).BindError()
	assert.Equal(t, []uint8(nil), destu8)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint8", "Uint"))

	var destu []uint
	err = QueryParamsBinder(c).FailFast(true).Uints("param", &destu).BindError()
	assert.Equal(t, []uint(nil), destu)
	assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint", "Uint"))
}
// TestValueBinder_Bool checks Bool and MustBool binding of the first "param" query
// value. The cases cover:
//   - successful binding ("true" and "1" both bind as true);
//   - a missing parameter, which is an error only for the Must variant;
//   - fail-fast on a previously recorded error;
//   - conversion failures.
//
// In every case that does not bind, the destination must stay false.
func TestValueBinder_Bool(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		givenFailFast bool
		whenMust      bool
		expectValue   bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=true&param=1",
			expectValue: true,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: false,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   false,
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: false,
			expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1&param=100",
			expectValue: true,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: false,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   false,
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: false,
			expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Simulate an error recorded by an earlier binding call.
				b.errors = []error{errors.New("previous error")}
			}
			dest := false
			var err error
			if tc.whenMust {
				err = b.MustBool("param", &dest).BindError()
			} else {
				err = b.Bool("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Bools checks Bools and MustBools binding of every "param" query
// value into a slice, in order ("true", "false", "1", "0" are all accepted). A missing
// parameter, errors pre-seeded via givenBindErrors, and conversion failures (with and
// without fail-fast) must all leave the destination nil.
func TestValueBinder_Bools(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     []bool
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=true&param=false&param=1&param=0",
			expectValue: []bool{true, false, true, false},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []bool(nil),
		},
		{
			name:            "nok, previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenURL:         "/search?param=1&param=100",
			expectValue:     []bool(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=true&param=nope&param=100",
			expectValue: []bool(nil),
			expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:          "nok, conversion fails fast, value is not changed",
			givenFailFast: true,
			whenURL:       "/search?param=true&param=nope&param=100",
			expectValue:   []bool(nil),
			expectError:   "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=true&param=false&param=1&param=0",
			expectValue: []bool{true, false, true, false},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []bool(nil),
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:            "nok (must), previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenMust:        true,
			whenURL:         "/search?param=1&param=100",
			expectValue:     []bool(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []bool(nil),
			expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			// Seed errors from "earlier" binding calls; nil for most cases.
			b.errors = tc.givenBindErrors
			var dest []bool
			var err error
			if tc.whenMust {
				err = b.MustBools("param", &dest).BindError()
			} else {
				err = b.Bools("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Float64 checks Float64 and MustFloat64 binding of the first "param"
// query value. The cases cover:
//   - successful binding;
//   - a missing parameter, which is an error only for the Must variant;
//   - fail-fast on a previously recorded error;
//   - conversion failures.
//
// In every case that does not bind, the destination must keep its initial value of 1.123.
func TestValueBinder_Float64(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   float64
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=4.3&param=1",
			expectValue: 4.3,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: 1.123,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   1.123,
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: 1.123,
			expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=4.3&param=100",
			expectValue: 4.3,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: 1.123,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   1.123,
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: 1.123,
			expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Simulate an error recorded by an earlier binding call.
				b.errors = []error{errors.New("previous error")}
			}
			dest := 1.123
			var err error
			if tc.whenMust {
				err = b.MustFloat64("param", &dest).BindError()
			} else {
				err = b.Float64("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Float64s checks Float64s and MustFloat64s binding of every "param"
// query value into a slice, in order. A missing parameter, errors pre-seeded via
// givenBindErrors, and conversion failures (with and without fail-fast) must all leave
// the destination nil.
func TestValueBinder_Float64s(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     []float64
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=4.3&param=0",
			expectValue: []float64{4.3, 0},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []float64(nil),
		},
		{
			name:            "nok, previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenURL:         "/search?param=1&param=100",
			expectValue:     []float64(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []float64(nil),
			expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:          "nok, conversion fails fast, value is not changed",
			givenFailFast: true,
			whenURL:       "/search?param=0&param=nope&param=100",
			expectValue:   []float64(nil),
			expectError:   "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=4.3&param=0",
			expectValue: []float64{4.3, 0},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []float64(nil),
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:            "nok (must), previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenMust:        true,
			whenURL:         "/search?param=1&param=100",
			expectValue:     []float64(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []float64(nil),
			expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			// Seed errors from "earlier" binding calls; nil for most cases.
			b.errors = tc.givenBindErrors
			var dest []float64
			var err error
			if tc.whenMust {
				err = b.MustFloat64s("param", &dest).BindError()
			} else {
				err = b.Float64s("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Float32 checks Float32 and MustFloat32 binding of the first "param"
// query value. The cases cover:
//   - successful binding;
//   - a missing parameter, which is an error only for the Must variant;
//   - fail-fast on a previously recorded error;
//   - conversion failures.
//
// In every case that does not bind, the destination must keep its initial value of 1.123.
func TestValueBinder_Float32(t *testing.T) {
	var testCases = []struct {
		name        string
		whenURL     string
		expectError string
		expectValue float32
		// givenFailFast enables fail-fast on the binder and seeds a "previous error".
		// It was formerly named givenNoFailFast, which inverted its actual meaning.
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=4.3&param=1",
			expectValue: 4.3,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: 1.123,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   1.123,
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: 1.123,
			expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=4.3&param=100",
			expectValue: 4.3,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: 1.123,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   1.123,
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: 1.123,
			expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Simulate an error recorded by an earlier binding call.
				b.errors = []error{errors.New("previous error")}
			}
			dest := float32(1.123)
			var err error
			if tc.whenMust {
				err = b.MustFloat32("param", &dest).BindError()
			} else {
				err = b.Float32("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Float32s checks Float32s and MustFloat32s binding of every "param"
// query value into a slice, in order. A missing parameter, errors pre-seeded via
// givenBindErrors, and conversion failures (with and without fail-fast) must all leave
// the destination nil.
func TestValueBinder_Float32s(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     []float32
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=4.3&param=0",
			expectValue: []float32{4.3, 0},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []float32(nil),
		},
		{
			name:            "nok, previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenURL:         "/search?param=1&param=100",
			expectValue:     []float32(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []float32(nil),
			expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:          "nok, conversion fails fast, value is not changed",
			givenFailFast: true,
			whenURL:       "/search?param=0&param=nope&param=100",
			expectValue:   []float32(nil),
			expectError:   "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=4.3&param=0",
			expectValue: []float32{4.3, 0},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []float32(nil),
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:            "nok (must), previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenMust:        true,
			whenURL:         "/search?param=1&param=100",
			expectValue:     []float32(nil),
			expectError:     "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []float32(nil),
			expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			// Seed errors from "earlier" binding calls; nil for most cases.
			b.errors = tc.givenBindErrors
			var dest []float32
			var err error
			if tc.whenMust {
				err = b.MustFloat32s("param", &dest).BindError()
			} else {
				err = b.Float32s("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Time checks Time and MustTime binding of the first "param" query
// value using the given layout (RFC3339 here; the "+" in offsets is URL-encoded as %2B).
// A missing parameter (an error only for MustTime) and fail-fast on a previously
// recorded error must leave the destination as the zero time.Time.
func TestValueBinder_Time(t *testing.T) {
	exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
	var testCases = []struct {
		expectValue   time.Time
		name          string
		whenURL       string
		whenLayout    string
		expectError   string
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=2020-12-23T09:45:31%2B02:00&param=2000-01-02T09:45:31%2B00:00",
			whenLayout:  time.RFC3339,
			expectValue: exampleTime,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: time.Time{},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   time.Time{},
			expectError:   "previous error",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=2020-12-23T09:45:31%2B02:00&param=2000-01-02T09:45:31%2B00:00",
			whenLayout:  time.RFC3339,
			expectValue: exampleTime,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: time.Time{},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   time.Time{},
			expectError:   "previous error",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Simulate an error recorded by an earlier binding call.
				b.errors = []error{errors.New("previous error")}
			}
			dest := time.Time{}
			var err error
			if tc.whenMust {
				err = b.MustTime("param", &dest, tc.whenLayout).BindError()
			} else {
				err = b.Time("param", &dest, tc.whenLayout).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Times exercises ValueBinder.Times and ValueBinder.MustTimes.
// The URLs repeat "param" (joined with '&') so every value is parsed with the
// layout and appended to the destination slice in query order. Fail-fast cases
// pre-load the binder with an error and expect the slice to stay nil.
func TestValueBinder_Times(t *testing.T) {
	exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
	exampleTime2, _ := time.Parse(time.RFC3339, "2000-01-02T09:45:31+00:00")
	var testCases = []struct {
		name            string
		whenURL         string
		whenLayout      string
		expectError     string
		givenBindErrors []error
		expectValue     []time.Time
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=2020-12-23T09:45:31%2B02:00&param=2000-01-02T09:45:31%2B00:00",
			whenLayout:  time.RFC3339,
			expectValue: []time.Time{exampleTime, exampleTime2},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []time.Time(nil),
		},
		{
			name:            "nok, previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenURL:         "/search?param=1&param=100",
			expectValue:     []time.Time(nil),
			expectError:     "previous error",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=2020-12-23T09:45:31%2B02:00&param=2000-01-02T09:45:31%2B00:00",
			whenLayout:  time.RFC3339,
			expectValue: []time.Time{exampleTime, exampleTime2},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []time.Time(nil),
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:            "nok (must), previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenMust:        true,
			whenURL:         "/search?param=1&param=100",
			expectValue:     []time.Time(nil),
			expectError:     "previous error",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			b.errors = tc.givenBindErrors
			// Fall back to RFC3339 when a case does not specify a layout.
			layout := time.RFC3339
			if tc.whenLayout != "" {
				layout = tc.whenLayout
			}
			var dest []time.Time
			var err error
			if tc.whenMust {
				err = b.MustTimes("param", &dest, layout).BindError()
			} else {
				err = b.Times("param", &dest, layout).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Duration exercises ValueBinder.Duration and
// ValueBinder.MustDuration. The URLs repeat "param" (joined with '&'); only the
// first value (42s) is expected in the destination, the second (1ms) is ignored.
func TestValueBinder_Duration(t *testing.T) {
	example := 42 * time.Second
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     time.Duration
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=42s&param=1ms",
			expectValue: example,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: 0,
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   0,
			expectError:   "previous error",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=42s&param=1ms",
			expectValue: example,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: 0,
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   0,
			expectError:   "previous error",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier error so fail-fast mode must skip binding.
				b.errors = []error{errors.New("previous error")}
			}
			var dest time.Duration
			var err error
			if tc.whenMust {
				err = b.MustDuration("param", &dest).BindError()
			} else {
				err = b.Duration("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_Durations exercises ValueBinder.Durations and
// ValueBinder.MustDurations. The URLs repeat "param" (joined with '&') and all
// values are expected in the destination slice, in query order.
func TestValueBinder_Durations(t *testing.T) {
	exampleDuration := 42 * time.Second
	exampleDuration2 := 1 * time.Millisecond
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     []time.Duration
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=42s&param=1ms",
			expectValue: []time.Duration{exampleDuration, exampleDuration2},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []time.Duration(nil),
		},
		{
			name:            "nok, previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenURL:         "/search?param=1&param=100",
			expectValue:     []time.Duration(nil),
			expectError:     "previous error",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=42s&param=1ms",
			expectValue: []time.Duration{exampleDuration, exampleDuration2},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []time.Duration(nil),
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:            "nok (must), previous errors fail fast without binding value",
			givenFailFast:   true,
			givenBindErrors: []error{errors.New("previous error")},
			whenMust:        true,
			whenURL:         "/search?param=1&param=100",
			expectValue:     []time.Duration(nil),
			expectError:     "previous error",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			b.errors = tc.givenBindErrors
			var dest []time.Duration
			var err error
			if tc.whenMust {
				err = b.MustDurations("param", &dest).BindError()
			} else {
				err = b.Durations("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_BindUnmarshaler exercises ValueBinder.BindUnmarshaler and
// ValueBinder.MustBindUnmarshaler using the package's Timestamp test type.
// Only the first "param" value is bound. A non-time value must surface the
// wrapped time.Parse error and leave the destination at its zero value.
func TestValueBinder_BindUnmarshaler(t *testing.T) {
	exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
	var testCases = []struct {
		expectValue     Timestamp
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=2020-12-23T09:45:31%2B02:00&param=2000-01-02T09:45:31%2B00:00",
			expectValue: Timestamp(exampleTime),
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: Timestamp{},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   Timestamp{},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: Timestamp{},
			expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=2020-12-23T09:45:31%2B02:00&param=2000-01-02T09:45:31%2B00:00",
			expectValue: Timestamp(exampleTime),
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: Timestamp{},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   Timestamp{},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: Timestamp{},
			expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier error so fail-fast mode must skip binding.
				b.errors = []error{errors.New("previous error")}
			}
			var dest Timestamp
			var err error
			if tc.whenMust {
				err = b.MustBindUnmarshaler("param", &dest).BindError()
			} else {
				err = b.BindUnmarshaler("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_JSONUnmarshaler exercises ValueBinder.JSONUnmarshaler and
// ValueBinder.MustJSONUnmarshaler with *big.Int as the json.Unmarshaler. Only
// the first "param" value is bound. Unparsable input must surface the wrapped
// math/big error and leave the destination at its zero value.
func TestValueBinder_JSONUnmarshaler(t *testing.T) {
	example := big.NewInt(999)
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		expectValue     big.Int
		givenBindErrors []error
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=999&param=998",
			expectValue: *example,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: big.Int{},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   big.Int{},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=xxx",
			expectValue: big.Int{},
			expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=999&param=998",
			expectValue: *example,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: big.Int{},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=xxx",
			expectValue:   big.Int{},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=xxx",
			expectValue: big.Int{},
			expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier error so fail-fast mode must skip binding.
				b.errors = []error{errors.New("previous error")}
			}
			var dest big.Int
			var err error
			if tc.whenMust {
				err = b.MustJSONUnmarshaler("param", &dest).BindError()
			} else {
				err = b.JSONUnmarshaler("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_TextUnmarshaler exercises ValueBinder.TextUnmarshaler and
// ValueBinder.MustTextUnmarshaler with *big.Int as the encoding.TextUnmarshaler.
// Only the first "param" value is bound. Unparsable input must surface the
// wrapped math/big error and leave the destination at its zero value.
func TestValueBinder_TextUnmarshaler(t *testing.T) {
	example := big.NewInt(999)
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		expectValue     big.Int
		givenBindErrors []error
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=999&param=998",
			expectValue: *example,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: big.Int{},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   big.Int{},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=xxx",
			expectValue: big.Int{},
			expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=999&param=998",
			expectValue: *example,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: big.Int{},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=xxx",
			expectValue:   big.Int{},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=xxx",
			expectValue: big.Int{},
			expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier error so fail-fast mode must skip binding.
				b.errors = []error{errors.New("previous error")}
			}
			var dest big.Int
			var err error
			if tc.whenMust {
				err = b.MustTextUnmarshaler("param", &dest).BindError()
			} else {
				err = b.TextUnmarshaler("param", &dest).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_BindWithDelimiter_types checks that BindWithDelimiter splits
// every "param" value on "," and appends the converted items, in order, for
// each supported slice element type. Cases without whenURL use the default
// "param=1,2&param=1", which must yield [1 2 1].
func TestValueBinder_BindWithDelimiter_types(t *testing.T) {
	var testCases = []struct {
		expect  any
		name    string
		whenURL string
	}{
		{
			name:   "ok, strings",
			expect: []string{"1", "2", "1"},
		},
		{
			name:   "ok, int64",
			expect: []int64{1, 2, 1},
		},
		{
			name:   "ok, int32",
			expect: []int32{1, 2, 1},
		},
		{
			name:   "ok, int16",
			expect: []int16{1, 2, 1},
		},
		{
			name:   "ok, int8",
			expect: []int8{1, 2, 1},
		},
		{
			name:   "ok, int",
			expect: []int{1, 2, 1},
		},
		{
			name:   "ok, uint64",
			expect: []uint64{1, 2, 1},
		},
		{
			name:   "ok, uint32",
			expect: []uint32{1, 2, 1},
		},
		{
			name:   "ok, uint16",
			expect: []uint16{1, 2, 1},
		},
		{
			name:   "ok, uint8",
			expect: []uint8{1, 2, 1},
		},
		{
			name:   "ok, uint",
			expect: []uint{1, 2, 1},
		},
		{
			name:   "ok, float64",
			expect: []float64{1, 2, 1},
		},
		{
			name:   "ok, float32",
			expect: []float32{1, 2, 1},
		},
		{
			name:    "ok, bool",
			whenURL: "/search?param=1,false&param=true",
			expect:  []bool{true, false, true},
		},
		{
			name:    "ok, Duration",
			whenURL: "/search?param=1s,42s&param=1ms",
			expect:  []time.Duration{1 * time.Second, 42 * time.Second, 1 * time.Millisecond},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			URL := "/search?param=1,2&param=1"
			if tc.whenURL != "" {
				URL = tc.whenURL
			}
			c := createTestContext(URL, nil, nil)
			b := QueryParamsBinder(c)
			// Each branch needs a destination of the concrete slice type so the
			// binder picks the matching element conversion.
			switch tc.expect.(type) {
			case []string:
				var dest []string
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []int64:
				var dest []int64
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []int32:
				var dest []int32
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []int16:
				var dest []int16
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []int8:
				var dest []int8
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []int:
				var dest []int
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []uint64:
				var dest []uint64
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []uint32:
				var dest []uint32
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []uint16:
				var dest []uint16
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []uint8:
				var dest []uint8
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []uint:
				var dest []uint
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []float64:
				var dest []float64
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []float32:
				var dest []float32
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []bool:
				var dest []bool
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			case []time.Duration:
				var dest []time.Duration
				assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
				assert.Equal(t, tc.expect, dest)
			default:
				assert.Fail(t, "invalid type")
			}
		})
	}
}
// TestValueBinder_BindWithDelimiter exercises ValueBinder.BindWithDelimiter and
// ValueBinder.MustBindWithDelimiter into an []int64. Each "param" value is
// split on ",", so "param=1,2&param=1" yields [1 2 1]. A non-numeric item must
// report the strconv error and leave the slice nil.
func TestValueBinder_BindWithDelimiter(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     []int64
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value",
			whenURL:     "/search?param=1,2&param=1",
			expectValue: []int64{1, 2, 1},
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: []int64(nil),
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []int64(nil),
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []int64(nil),
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1,2&param=1",
			expectValue: []int64{1, 2, 1},
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: []int64(nil),
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []int64(nil),
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []int64(nil),
			expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier error so fail-fast mode must skip binding.
				b.errors = []error{errors.New("previous error")}
			}
			var dest []int64
			var err error
			if tc.whenMust {
				err = b.MustBindWithDelimiter("param", &dest, ",").BindError()
			} else {
				err = b.BindWithDelimiter("param", &dest, ",").BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestBindWithDelimiter_invalidType checks that BindWithDelimiter rejects a
// destination slice whose element type it cannot convert to. It expects the
// "unsupported bind type" error and a nil destination.
func TestBindWithDelimiter_invalidType(t *testing.T) {
	c := createTestContext("/search?param=1&param=100", nil, nil)
	b := QueryParamsBinder(c)

	var dest []BindUnmarshaler
	err := b.BindWithDelimiter("param", &dest, ",").BindError()

	assert.Equal(t, []BindUnmarshaler(nil), dest)
	assert.EqualError(t, err, "code=400, message=unsupported bind type, field=param")
}
// TestValueBinder_UnixTime exercises ValueBinder.UnixTime and
// ValueBinder.MustUnixTime, which bind the first "param" value as Unix seconds.
// It includes a value above the int32 range (2147483648).
func TestValueBinder_UnixTime(t *testing.T) {
	exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43+00:00") // => 1609180603
	var testCases = []struct {
		expectValue     time.Time
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value, unix time in seconds",
			whenURL:     "/search?param=1609180603&param=1609180604",
			expectValue: exampleTime,
		},
		{
			name:        "ok, binds value, unix time over int32 value",
			whenURL:     "/search?param=2147483648&param=1609180604",
			expectValue: time.Unix(2147483648, 0),
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: time.Time{},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   time.Time{},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: time.Time{},
			expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1609180603&param=1609180604",
			expectValue: exampleTime,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: time.Time{},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   time.Time{},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: time.Time{},
			expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier error so fail-fast mode must skip binding.
				b.errors = []error{errors.New("previous error")}
			}
			dest := time.Time{}
			var err error
			if tc.whenMust {
				err = b.MustUnixTime("param", &dest).BindError()
			} else {
				err = b.UnixTime("param", &dest).BindError()
			}
			// Compare the instant and the UTC view so that a differing
			// time.Location on the bound value does not affect equality.
			assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
			assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_UnixTimeMilli exercises ValueBinder.UnixTimeMilli and
// ValueBinder.MustUnixTimeMilli, which bind the first "param" value as Unix
// milliseconds.
func TestValueBinder_UnixTimeMilli(t *testing.T) {
	exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140
	var testCases = []struct {
		expectValue     time.Time
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value, unix time in milliseconds",
			whenURL:     "/search?param=1647184410140&param=1647184410199",
			expectValue: exampleTime,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: time.Time{},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   time.Time{},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: time.Time{},
			expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1647184410140&param=1647184410199",
			expectValue: exampleTime,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: time.Time{},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   time.Time{},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: time.Time{},
			expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier error so fail-fast mode must skip binding.
				b.errors = []error{errors.New("previous error")}
			}
			dest := time.Time{}
			var err error
			if tc.whenMust {
				err = b.MustUnixTimeMilli("param", &dest).BindError()
			} else {
				err = b.UnixTimeMilli("param", &dest).BindError()
			}
			// Compare the instant and the UTC view so that a differing
			// time.Location on the bound value does not affect equality.
			assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
			assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_UnixTimeNano exercises ValueBinder.UnixTimeNano and
// ValueBinder.MustUnixTimeNano, which bind the first "param" value as Unix
// nanoseconds. It covers whole-second, sub-second and below-one-second inputs.
func TestValueBinder_UnixTimeNano(t *testing.T) {
	exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00")             // => 1609180603
	exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789
	exampleTimeNanoBelowSec, _ := time.Parse(time.RFC3339Nano, "1970-01-01T00:00:00.999999999+00:00")
	var testCases = []struct {
		expectValue     time.Time
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "ok, binds value, unix time in nano seconds (sec precision)",
			whenURL:     "/search?param=1609180603000000000&param=1609180604",
			expectValue: exampleTime,
		},
		{
			name:        "ok, binds value, unix time in nano seconds",
			whenURL:     "/search?param=1609180603123456789&param=1609180604",
			expectValue: exampleTimeNano,
		},
		{
			name:        "ok, binds value, unix time in nano seconds (below 1 sec)",
			whenURL:     "/search?param=999999999&param=1609180604",
			expectValue: exampleTimeNanoBelowSec,
		},
		{
			name:        "ok, params values empty, value is not changed",
			whenURL:     "/search?nope=1",
			expectValue: time.Time{},
		},
		{
			name:          "nok, previous errors fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   time.Time{},
			expectError:   "previous error",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: time.Time{},
			expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
		{
			name:        "ok (must), binds value",
			whenMust:    true,
			whenURL:     "/search?param=1609180603000000000&param=1609180604",
			expectValue: exampleTime,
		},
		{
			name:        "ok (must), params values empty, returns error, value is not changed",
			whenMust:    true,
			whenURL:     "/search?nope=1",
			expectValue: time.Time{},
			expectError: "code=400, message=required field value is empty, field=param",
		},
		{
			name:          "nok (must), previous errors fail fast without binding value",
			givenFailFast: true,
			whenMust:      true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   time.Time{},
			expectError:   "previous error",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: time.Time{},
			expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier error so fail-fast mode must skip binding.
				b.errors = []error{errors.New("previous error")}
			}
			dest := time.Time{}
			var err error
			if tc.whenMust {
				err = b.MustUnixTimeNano("param", &dest).BindError()
			} else {
				err = b.UnixTimeNano("param", &dest).BindError()
			}
			// Compare the instant and the UTC view so that a differing
			// time.Location on the bound value does not affect equality.
			assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
			assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// BenchmarkDefaultBinder_BindInt64_single measures DefaultBinder.Bind filling a
// single int64 field tagged `query:"param"` from a request whose query string
// is "param=1&param=100".
func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
	type Opts struct {
		Param int64 `query:"param"`
	}
	c := createTestContext("/search?param=1&param=100", nil, nil)
	b.ReportAllocs()
	b.ResetTimer()
	binder := new(DefaultBinder)
	for i := 0; i < b.N; i++ {
		var dest Opts
		_ = binder.Bind(c, &dest)
	}
}
// BenchmarkValueBinder_BindInt64_single measures binding one int64 query
// parameter with the fluent ValueBinder API. It reuses a single
// QueryParamsBinder across iterations.
func BenchmarkValueBinder_BindInt64_single(b *testing.B) {
	c := createTestContext("/search?param=1&param=100", nil, nil)
	b.ReportAllocs()
	b.ResetTimer()
	type Opts struct {
		Param int64
	}
	binder := QueryParamsBinder(c)
	for i := 0; i < b.N; i++ {
		var dest Opts
		_ = binder.Int64("param", &dest.Param).BindError()
	}
}
// BenchmarkRawFunc_Int64_single is a baseline for the binder benchmarks. It
// parses the first "param" query value with strconv.Atoi via a hand-written
// helper that falls back to a default for empty input.
func BenchmarkRawFunc_Int64_single(b *testing.B) {
	c := createTestContext("/search?param=1&param=100", nil, nil)
	// rawFunc returns (defaultValue, true) for empty input, (0, false) on a
	// parse error and (parsed, true) otherwise.
	rawFunc := func(input string, defaultValue int64) (int64, bool) {
		if input == "" {
			return defaultValue, true
		}
		n, err := strconv.Atoi(input)
		if err != nil {
			return 0, false
		}
		return int64(n), true
	}
	b.ReportAllocs()
	b.ResetTimer()
	type Opts struct {
		Param int64
	}
	for i := 0; i < b.N; i++ {
		var dest Opts
		if n, ok := rawFunc(c.QueryParam("param"), 1); ok {
			dest.Param = n
		}
	}
}
// BenchmarkDefaultBinder_BindInt64_10_fields measures the tag-driven
// DefaultBinder.Bind populating ten query-tagged fields (strings, signed and
// unsigned integers). Each iteration sanity-checks that int64 was bound to 1.
func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
	type Opts struct {
		String  string   `query:"string"`
		Strings []string `query:"strings"`
		Int64   int64    `query:"int64"`
		Uint64  uint64   `query:"uint64"`
		Int32   int32    `query:"int32"`
		Uint32  uint32   `query:"uint32"`
		Int16   int16    `query:"int16"`
		Uint16  uint16   `query:"uint16"`
		Int8    int8     `query:"int8"`
		Uint8   uint8    `query:"uint8"`
	}
	const target = "/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second"
	echoCtx := createTestContext(target, nil, nil)

	b.ReportAllocs()
	b.ResetTimer()

	defaultBinder := new(DefaultBinder)
	for remaining := b.N; remaining > 0; remaining-- {
		opts := Opts{}
		_ = defaultBinder.Bind(echoCtx, &opts)
		if opts.Int64 != 1 {
			b.Fatalf("int64!=1")
		}
	}
}
// BenchmarkValueBinder_BindInt64_10_fields measures the fluent ValueBinder
// binding the same ten query parameters as
// BenchmarkDefaultBinder_BindInt64_10_fields. Every field reads the key that
// matches its struct tag, so the two benchmarks do equivalent work.
func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) {
	type Opts struct {
		String  string   `query:"string"`
		Strings []string `query:"strings"`
		Int64   int64    `query:"int64"`
		Uint64  uint64   `query:"uint64"`
		Int32   int32    `query:"int32"`
		Uint32  uint32   `query:"uint32"`
		Int16   int16    `query:"int16"`
		Uint16  uint16   `query:"uint16"`
		Int8    int8     `query:"int8"`
		Uint8   uint8    `query:"uint8"`
	}
	c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil)
	b.ReportAllocs()
	b.ResetTimer()
	binder := QueryParamsBinder(c)
	for i := 0; i < b.N; i++ {
		var dest Opts
		_ = binder.
			Int64("int64", &dest.Int64).
			Int32("int32", &dest.Int32).
			Int16("int16", &dest.Int16).
			Int8("int8", &dest.Int8).
			String("string", &dest.String).
			// The unsigned fields must read the "uint*" keys; previously they
			// re-read the signed "int*" keys.
			Uint64("uint64", &dest.Uint64).
			Uint32("uint32", &dest.Uint32).
			Uint16("uint16", &dest.Uint16).
			Uint8("uint8", &dest.Uint8).
			Strings("strings", &dest.Strings).
			BindError()
		if dest.Int64 != 1 {
			b.Fatalf("int64!=1")
		}
		if dest.Uint64 != 5 {
			b.Fatalf("uint64!=5")
		}
	}
}
// TestValueBinder_TimeError checks the error reported by ValueBinder.Time and
// ValueBinder.MustTime when the first "param" value cannot be parsed. No case
// sets whenLayout, so an empty layout is passed through. The expected
// time.Parse messages ("extra text") reflect that empty layout.
func TestValueBinder_TimeError(t *testing.T) {
	var testCases = []struct {
		expectValue     time.Time
		name            string
		whenURL         string
		whenLayout      string
		expectError     string
		givenBindErrors []error
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: time.Time{},
			expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: time.Time{},
			expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed an earlier error so fail-fast mode must skip binding.
				b.errors = []error{errors.New("previous error")}
			}
			dest := time.Time{}
			var err error
			if tc.whenMust {
				err = b.MustTime("param", &dest, tc.whenLayout).BindError()
			} else {
				err = b.Time("param", &dest, tc.whenLayout).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_TimesError checks the error reported by ValueBinder.Times and
// ValueBinder.MustTimes when a "param" value is not valid RFC3339 (the default
// layout used below). The destination slice must stay nil.
func TestValueBinder_TimesError(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		whenLayout      string
		expectError     string
		givenBindErrors []error
		expectValue     []time.Time
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:          "nok, fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []time.Time(nil),
			expectError:   "code=400, message=failed to bind field value to Time, err=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []time.Time(nil),
			expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []time.Time(nil),
			expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			b.errors = tc.givenBindErrors
			// Fall back to RFC3339 when a case does not specify a layout.
			layout := time.RFC3339
			if tc.whenLayout != "" {
				layout = tc.whenLayout
			}
			var dest []time.Time
			var err error
			if tc.whenMust {
				err = b.MustTimes("param", &dest, layout).BindError()
			} else {
				err = b.Times("param", &dest, layout).BindError()
			}
			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_DurationError checks that Duration/MustDuration report conversion errors for query values
// and leave the destination untouched.
//
// Each URL carries two values for `param` (`?param=X&param=100`); only the first one is bound.
func TestValueBinder_DurationError(t *testing.T) {
	var testCases = []struct {
		name          string
		whenURL       string
		expectError   string
		expectValue   time.Duration
		givenFailFast bool
		whenMust      bool
	}{
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: 0,
			expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: 0,
			expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			if tc.givenFailFast {
				// Seed a prior error so fail-fast mode skips binding.
				b.errors = []error{errors.New("previous error")}
			}

			var dest time.Duration
			var err error
			if tc.whenMust {
				err = b.MustDuration("param", &dest).BindError()
			} else {
				err = b.Duration("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValueBinder_DurationsError checks that Durations/MustDurations report conversion errors for query
// values and leave the destination slice untouched.
//
// Each URL carries two values for `param` (`?param=X&param=100`); only the first one fails to parse.
func TestValueBinder_DurationsError(t *testing.T) {
	var testCases = []struct {
		name            string
		whenURL         string
		expectError     string
		givenBindErrors []error
		expectValue     []time.Duration
		givenFailFast   bool
		whenMust        bool
	}{
		{
			name:          "nok, fail fast without binding value",
			givenFailFast: true,
			whenURL:       "/search?param=1&param=100",
			expectValue:   []time.Duration(nil),
			expectError:   "code=400, message=failed to bind field value to Duration, err=time: missing unit in duration \"1\", field=param",
		},
		{
			name:        "nok, conversion fails, value is not changed",
			whenURL:     "/search?param=nope&param=100",
			expectValue: []time.Duration(nil),
			expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
		},
		{
			name:        "nok (must), conversion fails, value is not changed",
			whenMust:    true,
			whenURL:     "/search?param=nope&param=100",
			expectValue: []time.Duration(nil),
			expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := createTestContext(tc.whenURL, nil, nil)
			b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
			b.errors = tc.givenBindErrors

			var dest []time.Duration
			var err error
			if tc.whenMust {
				err = b.MustDurations("param", &dest).BindError()
			} else {
				err = b.Durations("param", &dest).BindError()
			}

			assert.Equal(t, tc.expectValue, dest)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
================================================
FILE: codecov.yml
================================================
coverage:
status:
project:
default:
threshold: 1%
patch:
default:
threshold: 1%
comment:
require_changes: true
================================================
FILE: context.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"mime/multipart"
"net"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"sync"
)
const (
	// ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
	// Allow header is mandatory for status 405 (method not allowed) and useful for OPTIONS method requests.
	// It is added to context only when Router does not find matching method handler for request.
	ContextKeyHeaderAllow = "echo_header_allow"
)
const (
	// defaultMemory is default value for memory limit that is used when
	// parsing multipart forms (See (*http.Request).ParseMultipartForm).
	// It is used by newContext when no Echo instance is given; otherwise Echo.formParseMaxMemory is used.
	defaultMemory int64 = 32 << 20 // 32 MiB
	// indexPage is the file served by fsFile when the requested path is a directory.
	indexPage = "index.html"
)
// Context represents the context of the current HTTP request. It holds request and
// response objects, the matched route and its path, path parameters and user data stored with Set.
type Context struct {
	// request is the current request; replaced via SetRequest and Reset.
	request *http.Request
	// orgResponse is the *Response created for this context; Reset reuses it for the next request.
	orgResponse *Response
	// response is the writer handed out by Response(); starts as orgResponse but may be replaced
	// by SetResponse (e.g. with a wrapping writer).
	response http.ResponseWriter
	// query caches the parsed URL query; filled lazily by QueryParam/QueryParams, cleared by Reset.
	query url.Values
	// formParseMaxMemory is used for http.Request.ParseMultipartForm
	formParseMaxMemory int64
	// route is the matched route, set by InitializeRoute; nil before routing or when no route matched.
	route *RouteInfo
	// pathValues holds path parameter values; its capacity must fit the router's maximum param count
	// (see setPathValues).
	pathValues *PathValues
	// store holds values saved with Set; guarded by lock.
	store map[string]any
	// echo is the owning Echo instance; may be nil for contexts created with NewContext without one.
	echo *Echo
	// logger is returned by Logger().
	logger *slog.Logger
	// path is the registered route path, set by InitializeRoute/SetPath.
	path string
	// lock guards store (Get, Set, ContextGet).
	lock sync.RWMutex
}
// NewContext creates a Context for the given request and response writer.
//
// opts is scanned for an *Echo value; if several are passed the last one wins, and any other option
// types are ignored.
//
// Note: the request, response and Echo instance may all be nil, because Echo.ServeHTTP calls
// c.Reset(req, resp) anyway. Passing them is mainly useful when building contexts for tests.
func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context {
	var owner *Echo
	for _, opt := range opts {
		if candidate, isEcho := opt.(*Echo); isEcho {
			owner = candidate
		}
	}
	return newContext(r, w, owner)
}
// newContext builds a Context, taking the path-param capacity, logger and multipart memory limit from e
// when it is non-nil. A nil logger (or nil e) falls back to slog.Default().
func newContext(r *http.Request, w http.ResponseWriter, e *Echo) *Context {
	var (
		logger     *slog.Logger
		paramCap   int32
		maxFormMem = defaultMemory
	)
	if e != nil {
		paramCap = e.contextPathParamAllocSize.Load()
		logger = e.Logger
		maxFormMem = e.formParseMaxMemory
	}
	if logger == nil {
		logger = slog.Default()
	}

	values := make(PathValues, 0, paramCap)
	ctx := &Context{
		pathValues:         &values,
		store:              make(map[string]any),
		echo:               e,
		logger:             logger,
		formParseMaxMemory: maxFormMem,
	}
	ctx.SetRequest(r)
	ctx.orgResponse = NewResponse(w, logger)
	ctx.response = ctx.orgResponse
	return ctx
}
// Reset resets the context after request completes. It must be called along
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
// See `Echo#ServeHTTP()`
//
// It installs r and w, clears cached query values, stored data, route and path, and empties the path
// values while keeping their capacity. c.echo must be non-nil because its Logger is read here.
func (c *Context) Reset(r *http.Request, w http.ResponseWriter) {
	c.request = r
	c.orgResponse.reset(w)
	c.response = c.orgResponse
	c.query = nil
	// Set re-creates the map lazily, so nil is a valid "empty" store.
	c.store = nil
	// NOTE(review): unlike newContext there is no slog.Default() fallback here, so c.logger is nil
	// when c.echo.Logger is nil.
	c.logger = c.echo.Logger
	c.route = nil
	c.path = ""
	// NOTE: empty by setting length to 0. PathValues has to have capacity of c.echo.contextPathParamAllocSize at all times
	*c.pathValues = (*c.pathValues)[:0]
}
// writeContentType sets the Content-Type response header to value unless a non-empty one is already present.
func (c *Context) writeContentType(value string) {
	if h := c.response.Header(); h.Get(HeaderContentType) == "" {
		h.Set(HeaderContentType, value)
	}
}
// Request returns the *http.Request associated with the context.
func (c *Context) Request() *http.Request {
	return c.request
}

// SetRequest replaces the *http.Request associated with the context.
func (c *Context) SetRequest(r *http.Request) {
	c.request = r
}

// Response returns the http.ResponseWriter for the current request. It starts out as the context's
// *Response, but may be a wrapper installed with SetResponse.
func (c *Context) Response() http.ResponseWriter {
	return c.response
}

// SetResponse replaces the http.ResponseWriter. Some context methods and/or middleware require that the
// given writer implements `Unwrap() http.ResponseWriter`, which eventually should return the *echo.Response instance.
func (c *Context) SetResponse(r http.ResponseWriter) {
	c.response = r
}

// IsTLS reports whether the request arrived over TLS (http.Request.TLS is non-nil).
func (c *Context) IsTLS() bool {
	return c.request.TLS != nil
}

// IsWebSocket reports whether the request asks for a WebSocket upgrade: `Upgrade: websocket`
// (case-insensitive) and a `Connection` header containing "upgrade" (case-insensitive).
func (c *Context) IsWebSocket() bool {
	h := c.request.Header
	if !strings.EqualFold(h.Get(HeaderUpgrade), "websocket") {
		return false
	}
	return strings.Contains(strings.ToLower(h.Get(HeaderConnection)), "upgrade")
}

// Scheme returns the HTTP protocol scheme, `http` or `https`.
//
// A TLS connection always yields "https". Otherwise the first non-empty of X-Forwarded-Proto and
// X-Forwarded-Protocol is used, then X-Forwarded-Ssl == "on" means "https", then X-Url-Scheme, and
// finally "http".
func (c *Context) Scheme() string {
	// Can't use `r.Request.URL.Scheme`
	// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
	if c.IsTLS() {
		return "https"
	}
	h := c.request.Header
	if proto := h.Get(HeaderXForwardedProto); proto != "" {
		return proto
	}
	if proto := h.Get(HeaderXForwardedProtocol); proto != "" {
		return proto
	}
	if h.Get(HeaderXForwardedSsl) == "on" {
		return "https"
	}
	if proto := h.Get(HeaderXUrlScheme); proto != "" {
		return proto
	}
	return "http"
}
// RealIP returns the client's network address.
//
// When Echo#IPExtractor is configured it is used exclusively. Otherwise the legacy lookup order is:
//  1. the first comma-separated entry of the `X-Forwarded-For` header,
//  2. the `X-Real-IP` header,
//  3. the host part of http.Request.RemoteAddr.
//
// A leading `[` and trailing `]` (IPv6 literal brackets, e.g. `[::1]`) are stripped from header values.
// NOTE: these headers are supplied by the client and can be spoofed; configure Echo#IPExtractor when the
// result matters for security decisions.
func (c *Context) RealIP() string {
	if c.echo != nil && c.echo.IPExtractor != nil {
		return c.echo.IPExtractor(c.request)
	}
	// Fall back to legacy behavior
	if xff := c.request.Header.Get(HeaderXForwardedFor); xff != "" {
		first := xff
		if i := strings.IndexByte(xff, ','); i > 0 {
			first = xff[:i]
		}
		// Normalize single-entry values the same way as multi-entry ones; previously a lone
		// bracketed IPv6 entry such as `[::1]` was returned with its brackets.
		first = strings.TrimSpace(first)
		first = strings.TrimPrefix(first, "[")
		return strings.TrimSuffix(first, "]")
	}
	if ip := c.request.Header.Get(HeaderXRealIP); ip != "" {
		ip = strings.TrimPrefix(ip, "[")
		ip = strings.TrimSuffix(ip, "]")
		return ip
	}
	ra, _, _ := net.SplitHostPort(c.request.RemoteAddr)
	return ra
}
// Path returns the registered route path for the handler (empty until a route is initialized).
func (c *Context) Path() string {
	return c.path
}

// SetPath overrides the registered route path for the handler.
func (c *Context) SetPath(p string) {
	c.path = p
}

// RouteInfo returns a copy of the current request's route information: Method, Path, Name and params if they exist.
//
// An "empty" RouteInfo is returned when:
// * Context is accessed before Routing is done. For example inside Pre middlewares (`e.Pre()`)
// * Router did not find matching route - 404 (route not found)
// * Router did not find matching route with same method - 405 (method not allowed)
func (c *Context) RouteInfo() RouteInfo {
	if c.route == nil {
		return RouteInfo{}
	}
	return c.route.Clone()
}

// Param returns the path parameter with the given name, or "" when it is absent.
func (c *Context) Param(name string) string {
	return c.ParamOr(name, "")
}

// ParamOr returns the path parameter with the given name, or defaultValue when it is absent.
//
// Notes for DefaultRouter implementation:
// Path parameter could be empty for cases like that:
// * route `/release-:version/bin` and request URL is `/release-/bin`
// * route `/api/:version/image.jpg` and request URL is `/api//image.jpg`
// but not when path parameter is last part of route path
// * route `/download/file.:ext` will not match request `/download/file.`
func (c *Context) ParamOr(name, defaultValue string) string {
	return c.pathValues.GetOr(name, defaultValue)
}

// PathValues returns the path parameter values. The returned slice shares storage with the context.
func (c *Context) PathValues() PathValues {
	return *c.pathValues
}

// SetPathValues copies the given path parameters into the context. It panics when pathValues is nil.
func (c *Context) SetPathValues(pathValues PathValues) {
	if pathValues == nil {
		panic("context SetPathValues called with nil PathValues")
	}
	c.setPathValues(&pathValues)
}

// InitializeRoute stores the matched route, its path and the path parameter values in the context.
func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) {
	c.route = ri
	c.path = ri.Path
	c.setPathValues(pathValues)
}
// setPathValues copies *pv into the context's path values, keeping the existing backing slice whenever its
// capacity suffices so the router's invariant on c.pathValues capacity is preserved.
func (c *Context) setPathValues(pv *PathValues) {
	// Router accesses c.pathValues by index and may resize it to full capacity during routing
	// for that to work without going out-of-bounds we must make sure that c.pathValues slice is not replaced with smaller
	// slice than Router can set when routing Route with maximum amount of parameters.
	pathValues := c.pathValues
	if cap(*c.pathValues) < len(*pv) {
		// normally we should not end up here. pathValues is normally sized to Echo.contextPathParamAllocSize which should not
		// be smaller than anything router knows as maximum path parameter count to be.
		tmp := make(PathValues, len(*pv))
		c.pathValues = &tmp
		pathValues = c.pathValues
	} else if len(*c.pathValues) != len(*pv) {
		*pathValues = (*pathValues)[0:len(*pv)] // resize slice to given params length for copy to work
	}
	// copy (not assignment) so the context never aliases the caller's slice.
	copy(*pathValues, *pv)
}
// QueryParam returns the first query value for name, or "" when it is absent.
func (c *Context) QueryParam(name string) string {
	return c.QueryParams().Get(name)
}

// QueryParamOr returns the query value for name, or defaultValue when it is empty.
// Note: QueryParamOr does not distinguish if query had no value by that name or value was empty string
// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamOr("search", "1")`
func (c *Context) QueryParamOr(name, defaultValue string) string {
	if value := c.QueryParam(name); value != "" {
		return value
	}
	return defaultValue
}

// QueryParams returns the parsed URL query. The result is parsed once and cached on the context.
func (c *Context) QueryParams() url.Values {
	if c.query != nil {
		return c.query
	}
	c.query = c.request.URL.Query()
	return c.query
}

// QueryString returns the raw (still encoded) URL query string.
func (c *Context) QueryString() string {
	return c.request.URL.RawQuery
}
// FormValue returns the first form value for name (see http.Request.FormValue).
func (c *Context) FormValue(name string) string {
	return c.request.FormValue(name)
}

// FormValueOr returns the form value for name, or defaultValue when it is empty.
// Note: FormValueOr does not distinguish if form had no value by that name or value was empty string
func (c *Context) FormValueOr(name, defaultValue string) string {
	if value := c.FormValue(name); value != "" {
		return value
	}
	return defaultValue
}

// FormValues parses the request form and returns all values. Multipart bodies (by Content-Type prefix)
// are parsed with the context's multipart memory limit; others with http.Request.ParseForm.
func (c *Context) FormValues() (url.Values, error) {
	var parseErr error
	if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) {
		parseErr = c.request.ParseMultipartForm(c.formParseMaxMemory)
	} else {
		parseErr = c.request.ParseForm()
	}
	if parseErr != nil {
		return nil, parseErr
	}
	return c.request.Form, nil
}

// FormFile returns the multipart file header for the provided field name.
func (c *Context) FormFile(name string) (*multipart.FileHeader, error) {
	file, header, err := c.request.FormFile(name)
	if err != nil {
		return nil, err
	}
	// Only the header is returned (callers reopen it via FileHeader.Open), so the opened part is closed
	// straight away; the close error is intentionally ignored.
	_ = file.Close()
	return header, nil
}

// MultipartForm parses the multipart body with the context's memory limit and returns the parsed form
// together with any parse error.
func (c *Context) MultipartForm() (*multipart.Form, error) {
	parseErr := c.request.ParseMultipartForm(c.formParseMaxMemory)
	return c.request.MultipartForm, parseErr
}
// Cookie returns the named request cookie (see http.Request.Cookie).
func (c *Context) Cookie(name string) (*http.Cookie, error) {
	return c.request.Cookie(name)
}

// SetCookie adds a `Set-Cookie` header to the response.
func (c *Context) SetCookie(cookie *http.Cookie) {
	http.SetCookie(c.Response(), cookie)
}

// Cookies returns all cookies sent with the request.
func (c *Context) Cookies() []*http.Cookie {
	return c.request.Cookies()
}

// Get returns the value stored under key with Set.
// Method returns any(nil) when key does not exist which is different from typed nil (eg. []byte(nil)).
func (c *Context) Get(key string) any {
	c.lock.RLock()
	value := c.store[key]
	c.lock.RUnlock()
	return value
}

// Set stores val under key, creating the store on first use (it is nil after Reset).
func (c *Context) Set(key string, val any) {
	c.lock.Lock()
	defer c.lock.Unlock()
	if c.store == nil {
		c.store = map[string]any{key: val}
		return
	}
	c.store[key] = val
}
// Bind binds path params, query params and the request body into `i` using Echo#Binder.
// The default binder chooses the body decoder from the Content-Type header.
func (c *Context) Bind(i any) error {
	binder := c.echo.Binder
	return binder.Bind(c, i)
}

// Validate validates `i` with Echo#Validator, usually after `Context#Bind()`.
// Returns ErrValidatorNotRegistered when no validator is configured.
func (c *Context) Validate(i any) error {
	if v := c.echo.Validator; v != nil {
		return v.Validate(i)
	}
	return ErrValidatorNotRegistered
}
// Render executes the named template with data via Echo#Renderer and sends the result as text/html
// with the given status code. Returns ErrRendererNotRegistered when no renderer is configured.
func (c *Context) Render(code int, name string, data any) (err error) {
	renderer := c.echo.Renderer
	if renderer == nil {
		return ErrRendererNotRegistered
	}
	// Render into a buffer first: Renderer.Render may fail midway, and nothing (status or partial output)
	// must reach the client before the (global) error handler decides which status to send.
	// html.Template.ExecuteTemplate() documents that partial results may already be written on error.
	var rendered bytes.Buffer
	if err = renderer.Render(c, &rendered, name, data); err != nil {
		return err
	}
	return c.HTMLBlob(code, rendered.Bytes())
}

// HTML sends html as a text/html response with status code.
func (c *Context) HTML(code int, html string) (err error) {
	return c.HTMLBlob(code, []byte(html))
}

// HTMLBlob sends b as a text/html response with status code.
func (c *Context) HTMLBlob(code int, b []byte) (err error) {
	return c.Blob(code, MIMETextHTMLCharsetUTF8, b)
}

// String sends s as a text/plain response with status code.
func (c *Context) String(code int, s string) (err error) {
	return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s))
}
// jsonPBlob writes `callback(<json of i>);` as an application/javascript response.
// Unlike json, WriteHeader is called before serialization, so a Serialize failure happens after the
// status and the `callback(` prefix have already been handed to the writer.
func (c *Context) jsonPBlob(code int, callback string, i any) (err error) {
	c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
	c.response.WriteHeader(code)
	if _, err = c.response.Write([]byte(callback + "(")); err != nil {
		return err
	}
	if err = c.echo.JSONSerializer.Serialize(c, i, ""); err != nil {
		return err
	}
	_, err = c.response.Write([]byte(");"))
	return err
}
// json serializes i with Echo#JSONSerializer using the given indent and the given status code.
// The status is not written up front: when serialization fails, the (global) error handler must still be
// able to pick the status sent to the client. It is therefore stored on the underlying *Response, or, when
// the current writer cannot be unwrapped to one, kept in a temporary delayedStatusWriter until the first Write.
func (c *Context) json(code int, i any, indent string) error {
	c.writeContentType(MIMEApplicationJSON)
	r, unwrapErr := UnwrapResponse(c.response)
	if unwrapErr != nil {
		original := c.Response()
		c.SetResponse(&delayedStatusWriter{ResponseWriter: original, status: code})
		defer c.SetResponse(original)
	} else {
		r.Status = code
	}
	return c.echo.JSONSerializer.Serialize(c, i, indent)
}
// JSON serializes i as a compact JSON response with status code.
func (c *Context) JSON(code int, i any) (err error) {
	return c.json(code, i, "")
}

// JSONPretty serializes i as JSON using indent, with status code.
func (c *Context) JSONPretty(code int, i any, indent string) (err error) {
	return c.json(code, i, indent)
}

// JSONBlob sends the pre-encoded JSON b with status code.
func (c *Context) JSONBlob(code int, b []byte) (err error) {
	return c.Blob(code, MIMEApplicationJSON, b)
}

// JSONP serializes i as JSON wrapped in `callback(...);` with status code.
func (c *Context) JSONP(code int, callback string, i any) (err error) {
	return c.jsonPBlob(code, callback, i)
}

// JSONPBlob sends the pre-encoded JSON b wrapped in `callback(...);` with status code.
// Writing stops at the first failed write.
func (c *Context) JSONPBlob(code int, callback string, b []byte) (err error) {
	c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
	c.response.WriteHeader(code)
	for _, part := range [][]byte{[]byte(callback + "("), b, []byte(");")} {
		if _, err = c.response.Write(part); err != nil {
			return err
		}
	}
	return nil
}
// xml writes the status code, the standard XML header and then i encoded with encoding/xml,
// indented with indent when it is non-empty.
func (c *Context) xml(code int, i any, indent string) (err error) {
	c.writeContentType(MIMEApplicationXMLCharsetUTF8)
	c.response.WriteHeader(code)
	if _, err = c.response.Write([]byte(xml.Header)); err != nil {
		return err
	}
	encoder := xml.NewEncoder(c.response)
	if indent != "" {
		encoder.Indent("", indent)
	}
	return encoder.Encode(i)
}

// XML encodes i as an XML response with status code.
func (c *Context) XML(code int, i any) (err error) {
	return c.xml(code, i, "")
}

// XMLPretty encodes i as indented XML with status code.
func (c *Context) XMLPretty(code int, i any, indent string) (err error) {
	return c.xml(code, i, indent)
}

// XMLBlob sends the XML header followed by the pre-encoded XML b, with status code.
func (c *Context) XMLBlob(code int, b []byte) (err error) {
	c.writeContentType(MIMEApplicationXMLCharsetUTF8)
	c.response.WriteHeader(code)
	if _, err = c.response.Write([]byte(xml.Header)); err == nil {
		_, err = c.response.Write(b)
	}
	return err
}
// Blob sends b with the given status code and Content-Type (unless one is already set).
func (c *Context) Blob(code int, contentType string, b []byte) (err error) {
	c.writeContentType(contentType)
	c.response.WriteHeader(code)
	if _, err = c.response.Write(b); err != nil {
		return err
	}
	return nil
}

// Stream copies r to the response with the given status code and Content-Type (unless one is already set).
func (c *Context) Stream(code int, contentType string, r io.Reader) (err error) {
	c.writeContentType(contentType)
	c.response.WriteHeader(code)
	if _, err = io.Copy(c.response, r); err != nil {
		return err
	}
	return nil
}

// File serves file from Echo#Filesystem.
func (c *Context) File(file string) error {
	filesystem := c.echo.Filesystem
	return fsFile(c, file, filesystem)
}

// FileFS serves file from given file system.
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
// including `assets/images` as their prefix.
func (c *Context) FileFS(file string, filesystem fs.FS) error {
	return fsFile(c, file, filesystem)
}
// fsFile serves file from filesystem with http.ServeContent.
//
// If file names a directory, `index.html` inside it is served instead. It returns ErrNotFound when the file
// (or the directory index) cannot be opened, the Stat error when file metadata cannot be read, and an error
// when the opened file does not implement io.ReadSeeker.
func fsFile(c *Context, file string, filesystem fs.FS) error {
	// fs.FS.Open is stricter about path syntax than os.Open (see fs.ValidPath), while this function
	// historically accepted paths like ``, `.` or `..`, so the path is normalized first.
	file = path.Clean(file)
	f, err := filesystem.Open(file)
	if err != nil {
		return ErrNotFound
	}
	defer f.Close()

	// The Stat error used to be ignored, which let fi.IsDir() run on an unusable (possibly nil) FileInfo
	// and panic. Report it instead.
	fi, err := f.Stat()
	if err != nil {
		return err
	}
	if fi.IsDir() {
		file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect.
		f, err = filesystem.Open(file)
		if err != nil {
			return ErrNotFound
		}
		// The first deferred Close is bound to the directory handle; this one closes the index file.
		defer f.Close()
		if fi, err = f.Stat(); err != nil {
			return err
		}
	}

	ff, ok := f.(io.ReadSeeker)
	if !ok {
		return errors.New("file does not implement io.ReadSeeker")
	}
	http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
	return nil
}
// Attachment serves file with `Content-Disposition: attachment` so the client saves it as name.
func (c *Context) Attachment(file, name string) error {
	return c.contentDisposition(file, name, "attachment")
}

// Inline serves file with `Content-Disposition: inline` so the browser displays it, suggesting name.
func (c *Context) Inline(file, name string) error {
	return c.contentDisposition(file, name, "inline")
}

// quoteEscaper escapes backslashes and double quotes for use inside a quoted header parameter.
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")

// contentDisposition sets the Content-Disposition header (with name escaped by quoteEscaper) and serves file.
func (c *Context) contentDisposition(file, name, dispositionType string) error {
	escapedName := quoteEscaper.Replace(name)
	disposition := fmt.Sprintf(`%s; filename="%s"`, dispositionType, escapedName)
	c.response.Header().Set(HeaderContentDisposition, disposition)
	return c.File(file)
}
// NoContent writes only the status code, with no body.
func (c *Context) NoContent(code int) error {
	c.response.WriteHeader(code)
	return nil
}

// Redirect sets the Location header to url and writes code. Codes outside 300..308 (inclusive) are
// rejected with ErrInvalidRedirectCode and nothing is written.
func (c *Context) Redirect(code int, url string) error {
	if code >= 300 && code <= 308 {
		c.response.Header().Set(HeaderLocation, url)
		c.response.WriteHeader(code)
		return nil
	}
	return ErrInvalidRedirectCode
}
// Logger returns the logger for this context.
//
// Lookup order: the logger set on the context (SetLogger / construction / Reset), then Echo#Logger,
// then slog.Default(). The final fallback matches newContext and avoids returning nil (or dereferencing
// a nil Echo) when neither of the first two is set.
func (c *Context) Logger() *slog.Logger {
	if c.logger != nil {
		return c.logger
	}
	if c.echo != nil && c.echo.Logger != nil {
		return c.echo.Logger
	}
	return slog.Default()
}

// SetLogger sets the logger for this context. Passing nil makes Logger fall back to Echo#Logger or slog.Default().
func (c *Context) SetLogger(logger *slog.Logger) {
	c.logger = logger
}

// Echo returns the `Echo` instance that owns this context (may be nil for contexts built without one).
func (c *Context) Echo() *Echo {
	return c.echo
}
================================================
FILE: context_generic.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import "errors"
// ErrNonExistentKey is returned by ContextGet when the key does not exist in the context store.
var ErrNonExistentKey = errors.New("non existent key")
// ErrInvalidKeyType is returned by ContextGet/ContextGetOr when the stored value cannot be type-asserted to the expected type.
var ErrInvalidKeyType = errors.New("invalid key type")
// ContextGet returns the value stored under key, type-asserted to T.
// It returns ErrNonExistentKey when the key is missing and ErrInvalidKeyType when the value is not a T;
// in both cases the zero value of T is returned.
func ContextGet[T any](c *Context, key string) (T, error) {
	var zero T
	c.lock.RLock()
	defer c.lock.RUnlock()

	raw, found := c.store[key]
	if !found {
		return zero, ErrNonExistentKey
	}
	if typed, ok := raw.(T); ok {
		return typed, nil
	}
	return zero, ErrInvalidKeyType
}
// ContextGetOr retrieves a value from the context store or returns defaultValue (with a nil error) when the
// key is missing. Returns ErrInvalidKeyType, together with the zero value of T (not defaultValue), when the
// stored value is not castable to type T.
func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) {
	typed, err := ContextGet[T](c, key)
	if errors.Is(err, ErrNonExistentKey) {
		return defaultValue, nil
	}
	return typed, err
}
================================================
FILE: context_generic_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestContextGetOK(t *testing.T) {
	ctx := NewContext(nil, nil)
	ctx.Set("key", int64(123))

	got, err := ContextGet[int64](ctx, "key")

	assert.NoError(t, err)
	assert.Equal(t, int64(123), got)
}

func TestContextGetNonExistentKey(t *testing.T) {
	ctx := NewContext(nil, nil)
	ctx.Set("key", int64(123))

	got, err := ContextGet[int64](ctx, "nope")

	assert.ErrorIs(t, err, ErrNonExistentKey)
	assert.Zero(t, got)
}

func TestContextGetInvalidCast(t *testing.T) {
	ctx := NewContext(nil, nil)
	ctx.Set("key", int64(123))

	got, err := ContextGet[bool](ctx, "key")

	assert.ErrorIs(t, err, ErrInvalidKeyType)
	assert.False(t, got)
}

func TestContextGetOrOK(t *testing.T) {
	ctx := NewContext(nil, nil)
	ctx.Set("key", int64(123))

	got, err := ContextGetOr[int64](ctx, "key", 999)

	assert.NoError(t, err)
	assert.Equal(t, int64(123), got)
}

func TestContextGetOrNonExistentKey(t *testing.T) {
	ctx := NewContext(nil, nil)
	ctx.Set("key", int64(123))

	got, err := ContextGetOr[int64](ctx, "nope", 999)

	assert.NoError(t, err)
	assert.Equal(t, int64(999), got)
}

func TestContextGetOrInvalidCast(t *testing.T) {
	ctx := NewContext(nil, nil)
	ctx.Set("key", int64(123))

	got, err := ContextGetOr[float32](ctx, "key", float32(999))

	assert.ErrorIs(t, err, ErrInvalidKeyType)
	assert.Zero(t, got)
}
================================================
FILE: context_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"bytes"
"crypto/tls"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"text/template"
"time"
"github.com/stretchr/testify/assert"
)
// Template is a minimal test renderer assigned to Echo#Renderer in the render tests; it delegates to text/template.
type Template struct {
	templates *template.Template
}
// testUser is the fixture value serialized by the allocation benchmarks.
var testUser = user{ID: 1, Name: "Jon Snow"}
// BenchmarkAllocJSONP measures allocations of Context.JSONP for a small struct.
func BenchmarkAllocJSONP(b *testing.B) {
	e := New()
	e.Logger = slog.New(slog.DiscardHandler)
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	b.ReportAllocs()
	// b.Loop (Go 1.24+, already required by slog.DiscardHandler) excludes the setup above from timing
	// and keeps the loop body from being optimized away.
	for b.Loop() {
		_ = c.JSONP(http.StatusOK, "callback", testUser)
	}
}

// BenchmarkAllocJSON measures allocations of Context.JSON for a small struct.
func BenchmarkAllocJSON(b *testing.B) {
	e := New()
	e.Logger = slog.New(slog.DiscardHandler)
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	b.ReportAllocs()
	for b.Loop() {
		_ = c.JSON(http.StatusOK, testUser)
	}
}

// BenchmarkAllocXML measures allocations of Context.XML for a small struct.
func BenchmarkAllocXML(b *testing.B) {
	e := New()
	e.Logger = slog.New(slog.DiscardHandler)
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	b.ReportAllocs()
	for b.Loop() {
		_ = c.XML(http.StatusOK, testUser)
	}
}

// BenchmarkRealIPForHeaderXForwardFor measures the legacy X-Forwarded-For path of Context.RealIP.
func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) {
	c := Context{request: &http.Request{
		Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}},
	}}
	for b.Loop() {
		c.RealIP()
	}
}
// Render executes the named template with data into w; the Context is not used.
func (t *Template) Render(c *Context, w io.Writer, name string, data any) error {
	tmpl := t.templates
	return tmpl.ExecuteTemplate(w, name, data)
}
func TestContextEcho(t *testing.T) {
	e := New()
	c := e.NewContext(
		httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)),
		httptest.NewRecorder(),
	)

	assert.Equal(t, e, c.Echo())
}

func TestContextRequest(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
	c := e.NewContext(req, httptest.NewRecorder())

	got := c.Request()
	assert.NotNil(t, got)
	assert.Equal(t, req, got)
}

func TestContextResponse(t *testing.T) {
	e := New()
	c := e.NewContext(
		httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)),
		httptest.NewRecorder(),
	)

	assert.NotNil(t, c.Response())
}
func TestContextRenderTemplate(t *testing.T) {
	e := New()
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)), rec)
	c.Echo().Renderer = &Template{
		templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
	}

	err := c.Render(http.StatusOK, "hello", "Jon Snow")
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
}

func TestContextRenderTemplateError(t *testing.T) {
	// When template rendering fails nothing may be sent to the client yet, so the global error handler can decide what to do.
	e := New()
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)), rec)
	c.Echo().Renderer = &Template{
		templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
	}

	err := c.Render(http.StatusOK, "not_existing", "Jon Snow")

	assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`)
	assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
	assert.Empty(t, rec.Body.String())       // body must not be sent to the client
}

func TestContextRenderErrorsOnNoRenderer(t *testing.T) {
	e := New()
	c := e.NewContext(
		httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)),
		httptest.NewRecorder(),
	)
	c.Echo().Renderer = nil

	assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow"))
}
func TestContextStream(t *testing.T) {
	e := New()
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	pr, pw := io.Pipe()
	go func() {
		defer pw.Close()
		for i := range 3 {
			fmt.Fprintf(pw, "data: index %v\n\n", i)
			time.Sleep(5 * time.Millisecond)
		}
	}()

	err := c.Stream(http.StatusOK, "text/event-stream", pr)
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "text/event-stream", rec.Header().Get(HeaderContentType))
	assert.Equal(t, "data: index 0\n\ndata: index 1\n\ndata: index 2\n\n", rec.Body.String())
}
func TestContextHTML(t *testing.T) {
	rec := httptest.NewRecorder()
	c := NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	err := c.HTML(http.StatusOK, "Hi, Jon Snow")
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
	assert.Equal(t, "Hi, Jon Snow", rec.Body.String())
}

func TestContextHTMLBlob(t *testing.T) {
	rec := httptest.NewRecorder()
	c := NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow"))
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
	assert.Equal(t, "Hi, Jon Snow", rec.Body.String())
}
func TestContextJSON(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
c := e.NewContext(req, rec)
err := c.JSON(http.StatusOK, user{ID: 1, Name: "Jon Snow"})
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
assert.Equal(t, userJSON+"\n", rec.Body.String())
}
}
// TestContextJSONErrorsOut checks that an unencodable value makes JSON return the
// encoder error without sending a status code or body to the client.
func TestContextJSONErrorsOut(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	unsupported := make(chan bool)
	assert.EqualError(t, c.JSON(http.StatusOK, unsupported), "json: unsupported type: chan bool")
	assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
	assert.Empty(t, rec.Body.String())       // body must not be sent to the client
}
// TestContextJSONWithNotEchoResponse checks JSON error handling after SetResponse
// has replaced the context's response writer with a bare httptest recorder.
func TestContextJSONWithNotEchoResponse(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	c.SetResponse(rec)

	payload := map[string]float64{"foo": math.NaN()}
	err := c.JSON(http.StatusCreated, payload)
	assert.EqualError(t, err, "json: unsupported value: NaN")
	assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
	assert.Empty(t, rec.Body.String())       // body must not be sent to the client
}
// TestContextJSONPretty checks that JSONPretty writes indented JSON followed by a
// newline with the JSON content type.
func TestContextJSONPretty(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	payload := user{ID: 1, Name: "Jon Snow"}
	if assert.NoError(t, c.JSONPretty(http.StatusOK, payload, " ")) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
		assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
	}
}
// TestContextJSONWithEmptyIntent checks that JSONPretty with an empty indent
// produces exactly what a json.Encoder configured with the same empty indent
// produces for the same value.
func TestContextJSONWithEmptyIntent(t *testing.T) {
	e := New()
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	c := e.NewContext(req, rec)
	u := user{ID: 1, Name: "Jon Snow"}
	emptyIndent := ""

	// Build the expected output with the standard library encoder. Its error was
	// previously discarded, which would turn an encoder failure into a confusing
	// body mismatch further down.
	buf := new(bytes.Buffer)
	enc := json.NewEncoder(buf)
	enc.SetIndent(emptyIndent, emptyIndent)
	if !assert.NoError(t, enc.Encode(u)) {
		return
	}

	// Encode the same value that produced the expectation.
	err := c.JSONPretty(http.StatusOK, u, emptyIndent)
	if assert.NoError(t, err) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
		assert.Equal(t, buf.String(), rec.Body.String())
	}
}
// TestContextJSONP checks that JSONP wraps the encoded value in the callback call
// and uses the JavaScript UTF-8 content type.
func TestContextJSONP(t *testing.T) {
	const callback = "callback"
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	if assert.NoError(t, c.JSONP(http.StatusOK, callback, user{ID: 1, Name: "Jon Snow"})) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
		assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String())
	}
}
// TestContextJSONBlob checks that JSONBlob writes pre-encoded JSON bytes verbatim
// with the JSON content type.
func TestContextJSONBlob(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	blob, err := json.Marshal(user{ID: 1, Name: "Jon Snow"})
	assert.NoError(t, err)

	if assert.NoError(t, c.JSONBlob(http.StatusOK, blob)) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
		assert.Equal(t, userJSON, rec.Body.String())
	}
}
// TestContextJSONPBlob checks that JSONPBlob wraps pre-encoded JSON bytes in the
// callback call and uses the JavaScript UTF-8 content type.
func TestContextJSONPBlob(t *testing.T) {
	const callback = "callback"
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	blob, err := json.Marshal(user{ID: 1, Name: "Jon Snow"})
	assert.NoError(t, err)

	if assert.NoError(t, c.JSONPBlob(http.StatusOK, callback, blob)) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
		assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
	}
}
// TestContextXML checks that XML prefixes the encoded value with xml.Header and
// sets the XML UTF-8 content type.
func TestContextXML(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	payload := user{ID: 1, Name: "Jon Snow"}
	if assert.NoError(t, c.XML(http.StatusOK, payload)) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
		assert.Equal(t, xml.Header+userXML, rec.Body.String())
	}
}
// TestContextXMLPretty checks that XMLPretty writes xml.Header followed by the
// indented XML encoding of the value.
func TestContextXMLPretty(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	payload := user{ID: 1, Name: "Jon Snow"}
	if assert.NoError(t, c.XMLPretty(http.StatusOK, payload, " ")) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
		assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
	}
}
// TestContextXMLBlob checks that XMLBlob writes xml.Header followed by the
// pre-encoded XML bytes.
func TestContextXMLBlob(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	blob, err := xml.Marshal(user{ID: 1, Name: "Jon Snow"})
	assert.NoError(t, err)

	if assert.NoError(t, c.XMLBlob(http.StatusOK, blob)) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
		assert.Equal(t, xml.Header+userXML, rec.Body.String())
	}
}
// TestContextXMLWithEmptyIntent checks that XMLPretty with an empty indent
// produces xml.Header followed by exactly what an xml.Encoder configured with
// the same empty indent produces for the same value.
func TestContextXMLWithEmptyIntent(t *testing.T) {
	e := New()
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	c := e.NewContext(req, rec)
	u := user{ID: 1, Name: "Jon Snow"}
	emptyIndent := ""

	// Build the expected output with the standard library encoder. Its error was
	// previously discarded, which would hide an encoder failure behind a body mismatch.
	buf := new(bytes.Buffer)
	enc := xml.NewEncoder(buf)
	enc.Indent(emptyIndent, emptyIndent)
	if !assert.NoError(t, enc.Encode(u)) {
		return
	}

	// Encode the same value that produced the expectation.
	err := c.XMLPretty(http.StatusOK, u, emptyIndent)
	if assert.NoError(t, err) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
		assert.Equal(t, xml.Header+buf.String(), rec.Body.String())
	}
}
// TestContext_JSON_CommitsCustomResponseCode checks that a non-200 status code
// passed to JSON is the one written to the response.
func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
	e := New()
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	if assert.NoError(t, c.JSON(http.StatusCreated, user{ID: 1, Name: "Jon Snow"})) {
		assert.Equal(t, http.StatusCreated, rec.Code)
		assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
		assert.Equal(t, userJSON+"\n", rec.Body.String())
	}
}
// TestContextAttachment checks that Attachment serves the file with an
// "attachment" Content-Disposition header, escaping quotes and backslashes in
// the supplied file name.
func TestContextAttachment(t *testing.T) {
	testCases := []struct {
		name       string
		filename   string
		wantHeader string
	}{
		{
			name:       "ok",
			filename:   "walle.png",
			wantHeader: `attachment; filename="walle.png"`,
		},
		{
			name:       "ok, escape quotes in malicious filename",
			filename:   `malicious.sh"; \"; dummy=.txt`,
			wantHeader: `attachment; filename="malicious.sh\"; \\\"; dummy=.txt"`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			if assert.NoError(t, c.Attachment("_fixture/images/walle.png", tc.filename)) {
				assert.Equal(t, tc.wantHeader, rec.Header().Get(HeaderContentDisposition))
				assert.Equal(t, http.StatusOK, rec.Code)
				assert.Equal(t, 219885, rec.Body.Len())
			}
		})
	}
}
// TestContextInline checks that Inline serves the file with an "inline"
// Content-Disposition header, escaping quotes and backslashes in the supplied
// file name.
func TestContextInline(t *testing.T) {
	testCases := []struct {
		name       string
		filename   string
		wantHeader string
	}{
		{
			name:       "ok",
			filename:   "walle.png",
			wantHeader: `inline; filename="walle.png"`,
		},
		{
			name:       "ok, escape quotes in malicious filename",
			filename:   `malicious.sh"; \"; dummy=.txt`,
			wantHeader: `inline; filename="malicious.sh\"; \\\"; dummy=.txt"`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			if assert.NoError(t, c.Inline("_fixture/images/walle.png", tc.filename)) {
				assert.Equal(t, tc.wantHeader, rec.Header().Get(HeaderContentDisposition))
				assert.Equal(t, http.StatusOK, rec.Code)
				assert.Equal(t, 219885, rec.Body.Len())
			}
		})
	}
}
// TestContextNoContent checks that NoContent succeeds, writes the given status
// code and writes no body.
func TestContextNoContent(t *testing.T) {
	e := New()
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
	c := e.NewContext(req, rec)

	// The returned error used to be ignored. Also, 200 is httptest.ResponseRecorder's
	// default Code, so asserting 200 could not detect a missing WriteHeader; use 204.
	if assert.NoError(t, c.NoContent(http.StatusNoContent)) {
		assert.Equal(t, http.StatusNoContent, rec.Code)
		assert.Empty(t, rec.Body.String())
	}
}
// TestContextCookie checks reading a single cookie, reading all request cookies,
// and writing a cookie to the response.
func TestContextCookie(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	// Named themeCookie/userCookie so the package-level `user` test type is not shadowed.
	themeCookie := "theme=light"
	userCookie := "user=Jon Snow"
	req.Header.Add(HeaderCookie, themeCookie)
	req.Header.Add(HeaderCookie, userCookie)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	// Read single
	cookie, err := c.Cookie("theme")
	if assert.NoError(t, err) {
		assert.Equal(t, "theme", cookie.Name)
		assert.Equal(t, "light", cookie.Value)
	}

	// Read multiple. Collect into a map so a missing cookie fails the test; the
	// previous switch-in-a-loop passed vacuously when Cookies() returned nothing.
	got := make(map[string]string)
	for _, ck := range c.Cookies() {
		got[ck.Name] = ck.Value
	}
	assert.Equal(t, map[string]string{"theme": "light", "user": "Jon Snow"}, got)

	// Write
	cookie = &http.Cookie{
		Name:     "SSID",
		Value:    "Ap4PGTEq",
		Domain:   "labstack.com",
		Path:     "/",
		Expires:  time.Now(),
		Secure:   true,
		HttpOnly: true,
	}
	c.SetCookie(cookie)
	setCookie := rec.Header().Get(HeaderSetCookie)
	assert.Contains(t, setCookie, "SSID")
	assert.Contains(t, setCookie, "Ap4PGTEq")
	assert.Contains(t, setCookie, "labstack.com")
	assert.Contains(t, setCookie, "Secure")
	assert.Contains(t, setCookie, "HttpOnly")
}
// TestContext_PathValues checks that PathValues returns what SetPathValues stored.
func TestContext_PathValues(t *testing.T) {
	testCases := []struct {
		name string
		in   PathValues
		want PathValues
	}{
		{
			name: "param exists",
			in:   PathValues{{Name: "uid", Value: "101"}, {Name: "fid", Value: "501"}},
			want: PathValues{{Name: "uid", Value: "101"}, {Name: "fid", Value: "501"}},
		},
		{
			name: "params is empty",
			in:   PathValues{},
			want: PathValues{},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := New().NewContext(httptest.NewRequest(http.MethodGet, "/", nil), nil)
			c.SetPathValues(tc.in)
			assert.EqualValues(t, tc.want, c.PathValues())
		})
	}
}
// TestContext_PathParam checks Param lookups: the first matching name wins and a
// missing name yields an empty string.
func TestContext_PathParam(t *testing.T) {
	testCases := []struct {
		name  string
		in    PathValues
		param string
		want  string
	}{
		{
			name:  "param exists",
			in:    PathValues{{Name: "uid", Value: "101"}, {Name: "fid", Value: "501"}},
			param: "uid",
			want:  "101",
		},
		{
			name:  "multiple same param values exists - return first",
			in:    PathValues{{Name: "uid", Value: "101"}, {Name: "uid", Value: "202"}, {Name: "fid", Value: "501"}},
			param: "uid",
			want:  "101",
		},
		{
			name:  "param does not exists",
			in:    PathValues{{Name: "uid", Value: "101"}},
			param: "nope",
			want:  "",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := New().NewContext(httptest.NewRequest(http.MethodGet, "/", nil), nil)
			c.SetPathValues(tc.in)
			assert.EqualValues(t, tc.want, c.Param(tc.param))
		})
	}
}
// TestContext_PathParamDefault checks ParamOr: the default is used only when the
// name is absent, not when its value is empty.
func TestContext_PathParamDefault(t *testing.T) {
	testCases := []struct {
		name     string
		in       PathValues
		param    string
		fallback string
		want     string
	}{
		{
			name:     "param exists",
			in:       PathValues{{Name: "uid", Value: "101"}, {Name: "fid", Value: "501"}},
			param:    "uid",
			fallback: "999",
			want:     "101",
		},
		{
			name:     "param exists and is empty",
			in:       PathValues{{Name: "uid", Value: ""}, {Name: "fid", Value: "501"}},
			param:    "uid",
			fallback: "999",
			want:     "", // <-- this is different from QueryParamOr behaviour
		},
		{
			name:     "param does not exists",
			in:       PathValues{{Name: "uid", Value: "101"}},
			param:    "nope",
			fallback: "999",
			want:     "999",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := New().NewContext(httptest.NewRequest(http.MethodGet, "/", nil), nil)
			c.SetPathValues(tc.in)
			assert.EqualValues(t, tc.want, c.ParamOr(tc.param, tc.fallback))
		})
	}
}
// TestContextGetAndSetPathValuesMutability pins down aliasing and capacity rules
// of the context's internal path value storage.
func TestContextGetAndSetPathValuesMutability(t *testing.T) {
	t.Run("c.PathValues() does not return copy and modifying raw slice mutates value in context", func(t *testing.T) {
		e := New()
		e.contextPathParamAllocSize.Store(1)
		c := e.NewContext(httptest.NewRequest(http.MethodGet, "/:foo", nil), nil)

		given := PathValues{{Name: "foo", Value: "101"}}
		c.SetPathValues(given)

		// round-trip param values with modification
		raw := c.PathValues()
		assert.Equal(t, given, c.PathValues())

		// Writing through the returned slice must be visible via the context.
		raw[0] = PathValue{Name: "xxx", Value: "yyy"}
		assert.Equal(t, PathValues{PathValue{Name: "xxx", Value: "yyy"}}, c.PathValues())
	})

	t.Run("calling SetPathValues with bigger size changes capacity in context", func(t *testing.T) {
		e := New()
		e.contextPathParamAllocSize.Store(1)
		c := e.NewContext(httptest.NewRequest(http.MethodGet, "/:foo", nil), nil)

		// Two values exceed the configured allocation size of one.
		bigger := PathValues{
			{Name: "aaa", Value: "bbb"},
			{Name: "ccc", Value: "ddd"},
		}
		c.SetPathValues(bigger)
		assert.Equal(t, bigger, c.PathValues())

		// Reset must cope with the grown storage, empty it and keep its capacity.
		assert.NotPanics(t, func() { c.Reset(nil, nil) })
		assert.Equal(t, PathValues{}, c.PathValues())
		assert.Len(t, *c.pathValues, 0)
		assert.Equal(t, 2, cap(*c.pathValues))
	})

	t.Run("calling SetPathValues with smaller size slice does not change capacity in context", func(t *testing.T) {
		e := New()
		c := e.NewContext(httptest.NewRequest(http.MethodGet, "/:foo", nil), nil)
		c.pathValues = &PathValues{
			{Name: "aaa", Value: "bbb"},
			{Name: "ccc", Value: "ddd"},
		}

		// A shorter slice must not shrink the existing capacity of two.
		smaller := PathValues{{Name: "aaa", Value: "bbb"}}
		c.SetPathValues(smaller)
		assert.Equal(t, smaller, c.PathValues())

		assert.NotPanics(t, func() { c.Reset(nil, nil) })
		assert.Equal(t, PathValues{}, c.PathValues())
		assert.Len(t, *c.pathValues, 0)
		assert.Equal(t, 2, cap(*c.pathValues))
	})
}
// TestContext_SetParamNamesShouldNotModifyPathValuesCapacity is a regression test
// for issue #1655: SetPathValues with a growing number of values must leave
// Echo's contextPathParamAllocSize untouched.
func TestContext_SetParamNamesShouldNotModifyPathValuesCapacity(t *testing.T) {
	e := New()
	c := e.NewContext(nil, nil)
	assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())

	rounds := []PathValues{
		{{Name: "1", Value: "one"}, {Name: "2", Value: "two"}},
		{{Name: "1", Value: "one"}, {Name: "2", Value: "two"}, {Name: "3", Value: "three"}},
	}
	for _, want := range rounds {
		c.SetPathValues(want)
		assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
		assert.Equal(t, want, c.PathValues())
	}
}
// TestContextFormValue checks FormValue, FormValueOr and FormValues for a
// URL-encoded body, and that FormValues fails for a multipart Content-Type
// whose body is not valid multipart data.
func TestContextFormValue(t *testing.T) {
	form := url.Values{}
	form.Set("name", "Jon Snow")
	form.Set("email", "jon@labstack.com")
	encoded := form.Encode()
	e := New()

	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(encoded))
	req.Header.Add(HeaderContentType, MIMEApplicationForm)
	c := e.NewContext(req, nil)

	// FormValue
	assert.Equal(t, "Jon Snow", c.FormValue("name"))
	assert.Equal(t, "jon@labstack.com", c.FormValue("email"))

	// FormValueOr returns the value when present and the default for a missing key.
	assert.Equal(t, "Jon Snow", c.FormValueOr("name", "nope"))
	assert.Equal(t, "default", c.FormValueOr("missing", "default"))

	// FormValues
	if values, err := c.FormValues(); assert.NoError(t, err) {
		assert.Equal(t, url.Values{
			"name":  {"Jon Snow"},
			"email": {"jon@labstack.com"},
		}, values)
	}

	// Multipart FormParams error: the body is URL-encoded, not multipart.
	req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(encoded))
	req.Header.Add(HeaderContentType, MIMEMultipartForm)
	c = e.NewContext(req, nil)
	values, err := c.FormValues()
	assert.Nil(t, values)
	assert.Error(t, err)
}
// TestContext_QueryParams checks that QueryParams returns every decoded query
// value, including repeated keys, and an empty map for an empty query.
func TestContext_QueryParams(t *testing.T) {
	testCases := []struct {
		expect   url.Values
		name     string
		givenURL string
	}{
		{
			name:     "multiple values in url",
			givenURL: "/?test=1&test=2&email=jon%40labstack.com",
			expect:   url.Values{"test": {"1", "2"}, "email": {"jon@labstack.com"}},
		},
		{
			name:     "single value in url",
			givenURL: "/?nope=1",
			expect:   url.Values{"nope": {"1"}},
		},
		{
			name:     "no query params in url",
			givenURL: "/?",
			expect:   url.Values{},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := New().NewContext(httptest.NewRequest(http.MethodGet, tc.givenURL, nil), nil)
			assert.Equal(t, tc.expect, c.QueryParams())
		})
	}
}
// TestContext_QueryParam checks single-value query lookups: the first value wins
// and missing or empty values yield an empty string.
func TestContext_QueryParam(t *testing.T) {
	testCases := []struct {
		name     string
		givenURL string
		param    string
		want     string
	}{
		{name: "value exists in url", givenURL: "/?test=1", param: "test", want: "1"},
		{name: "multiple values exists in url", givenURL: "/?test=9&test=8", param: "test", want: "9"}, // <-- first value in returned
		{name: "value does not exists in url", givenURL: "/?nope=1", param: "test", want: ""},
		{name: "value is empty in url", givenURL: "/?test=", param: "test", want: ""},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := New().NewContext(httptest.NewRequest(http.MethodGet, tc.givenURL, nil), nil)
			assert.Equal(t, tc.want, c.QueryParam(tc.param))
		})
	}
}
// TestContext_QueryParamDefault checks QueryParamOr: the default is used both for
// a missing key and for a key whose value is empty.
func TestContext_QueryParamDefault(t *testing.T) {
	testCases := []struct {
		name     string
		givenURL string
		param    string
		fallback string
		want     string
	}{
		{name: "value exists in url", givenURL: "/?test=1", param: "test", fallback: "999", want: "1"},
		{name: "value does not exists in url", givenURL: "/?nope=1", param: "test", fallback: "999", want: "999"},
		{name: "value is empty in url", givenURL: "/?test=", param: "test", fallback: "999", want: "999"},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := New().NewContext(httptest.NewRequest(http.MethodGet, tc.givenURL, nil), nil)
			assert.Equal(t, tc.want, c.QueryParamOr(tc.param, tc.fallback))
		})
	}
}
// TestContextFormFile checks that FormFile returns the uploaded file header for a
// multipart request body.
func TestContextFormFile(t *testing.T) {
	const content = "test"
	e := New()
	buf := new(bytes.Buffer)
	mr := multipart.NewWriter(buf)
	w, err := mr.CreateFormFile("file", "test")
	if assert.NoError(t, err) {
		// Write errors used to be ignored; a short body would surface much later.
		_, err = w.Write([]byte(content))
		assert.NoError(t, err)
	}
	// Close writes the terminating boundary, so its error matters too.
	assert.NoError(t, mr.Close())

	req := httptest.NewRequest(http.MethodPost, "/", buf)
	req.Header.Set(HeaderContentType, mr.FormDataContentType())
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	f, err := c.FormFile("file")
	if assert.NoError(t, err) {
		assert.Equal(t, "test", f.Filename)
		assert.Equal(t, int64(len(content)), f.Size)
	}
}
// TestContextMultipartForm checks that MultipartForm parses both plain fields and
// file parts of a multipart request body.
func TestContextMultipartForm(t *testing.T) {
	e := New()
	buf := new(bytes.Buffer)
	mw := multipart.NewWriter(buf)
	// Errors from building the request body used to be ignored.
	assert.NoError(t, mw.WriteField("name", "Jon Snow"))
	fileContent := "This is a test file"
	w, err := mw.CreateFormFile("file", "test.txt")
	if assert.NoError(t, err) {
		_, err = w.Write([]byte(fileContent))
		assert.NoError(t, err)
	}
	// Close writes the terminating boundary.
	assert.NoError(t, mw.Close())

	req := httptest.NewRequest(http.MethodPost, "/", buf)
	req.Header.Set(HeaderContentType, mw.FormDataContentType())
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	f, err := c.MultipartForm()
	if assert.NoError(t, err) {
		assert.NotNil(t, f)
		// The plain field written above must be parsed as well.
		assert.Equal(t, []string{"Jon Snow"}, f.Value["name"])
		files := f.File["file"]
		if assert.Len(t, files, 1) {
			file := files[0]
			assert.Equal(t, "test.txt", file.Filename)
			assert.Equal(t, int64(len(fileContent)), file.Size)
		}
	}
}
// TestContextRedirect checks that Redirect sets the status code and Location
// header for a redirect status, and returns an error for status 310.
func TestContextRedirect(t *testing.T) {
	const target = "http://labstack.github.io/echo"
	e := New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	// assert.NoError (instead of assert.Equal(t, nil, err)) reports the actual error on failure.
	assert.NoError(t, c.Redirect(http.StatusMovedPermanently, target))
	assert.Equal(t, http.StatusMovedPermanently, rec.Code)
	assert.Equal(t, target, rec.Header().Get(HeaderLocation))
	assert.Error(t, c.Redirect(310, target))
}
// TestContextGet checks Set/Get on a zero-value Context, including missing keys
// and stored typed-nil values.
func TestContextGet(t *testing.T) {
	testCases := []struct {
		name  string
		given any
		key   string
		want  any
	}{
		{name: "ok, value exist", given: "Jon Snow", key: "key", want: "Jon Snow"},
		{name: "ok, value does not exist", given: "Jon Snow", key: "nope", want: nil},
		{name: "ok, value is nil value", given: []byte(nil), key: "key", want: []byte(nil)},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			c := &Context{}
			c.Set("key", tc.given)
			assert.Equal(t, tc.want, c.Get(tc.key))
		})
	}
}
func BenchmarkContext_Store(b *testing.B) {
e := &Echo{}
c := &Context{
echo: e,
}
for n := 0; n < b.N; n++ {
c.Set("name", "Jon Snow")
if c.Get("name") != "Jon Snow" {
b.Fail()
}
}
}
// validator is a test Validator that accepts every value.
type validator struct{}

// Validate always reports success, whatever value it is given.
func (*validator) Validate(_ any) error {
	return nil
}
// TestContext_Validate checks that Validate fails while Echo has no Validator and
// delegates to it once one is configured.
func TestContext_Validate(t *testing.T) {
	e := New()
	c := e.NewContext(nil, nil)
	payload := struct{}{}

	// No Validator configured on Echo yet.
	assert.Error(t, c.Validate(payload))

	e.Validator = &validator{}
	assert.NoError(t, c.Validate(payload))
}
// TestContext_QueryString checks that QueryString returns the raw query part of the URL.
func TestContext_QueryString(t *testing.T) {
	const query = "query=string&var=val"
	c := New().NewContext(httptest.NewRequest(http.MethodGet, "/?"+query, nil), nil)
	assert.Equal(t, query, c.QueryString())
}
// TestContext_Request checks that a zero-value Context has no request and that
// SetRequest stores the one returned by Request.
func TestContext_Request(t *testing.T) {
	c := &Context{}
	assert.Nil(t, c.Request())

	want := httptest.NewRequest(http.MethodGet, "/path", nil)
	c.SetRequest(want)
	assert.Equal(t, want, c.Request())
}
// TestContext_Scheme checks scheme detection from the TLS connection state and
// from the X-Forwarded-Proto, X-Forwarded-Protocol, X-Forwarded-Ssl and
// X-Url-Scheme headers, with "http" as the fallback.
//
// Each case runs as a named subtest so a failure identifies the offending input
// (previously all cases ran in one loop without names).
func TestContext_Scheme(t *testing.T) {
	tests := []struct {
		c      *Context
		name   string
		expect string
	}{
		{
			c:      &Context{request: &http.Request{TLS: &tls.ConnectionState{}}},
			name:   "TLS connection",
			expect: "https",
		},
		{
			c:      &Context{request: &http.Request{Header: http.Header{HeaderXForwardedProto: []string{"https"}}}},
			name:   "X-Forwarded-Proto https",
			expect: "https",
		},
		{
			c:      &Context{request: &http.Request{Header: http.Header{HeaderXForwardedProtocol: []string{"http"}}}},
			name:   "X-Forwarded-Protocol http",
			expect: "http",
		},
		{
			c:      &Context{request: &http.Request{Header: http.Header{HeaderXForwardedSsl: []string{"on"}}}},
			name:   "X-Forwarded-Ssl on",
			expect: "https",
		},
		{
			c:      &Context{request: &http.Request{Header: http.Header{HeaderXUrlScheme: []string{"https"}}}},
			name:   "X-Url-Scheme https",
			expect: "https",
		},
		{
			c:      &Context{request: &http.Request{}},
			name:   "no TLS and no headers",
			expect: "http",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expect, tt.c.Scheme())
		})
	}
}
// TestContext_IsWebSocket checks that only an Upgrade: websocket request with an
// "upgrade" Connection header (case-insensitively) is reported as a WebSocket.
func TestContext_IsWebSocket(t *testing.T) {
	// withHeaders builds a Context around a request carrying only the given headers.
	withHeaders := func(h http.Header) *Context {
		return &Context{request: &http.Request{Header: h}}
	}

	tests := []struct {
		given *Context
		check assert.BoolAssertionFunc
	}{
		{
			given: withHeaders(http.Header{HeaderUpgrade: []string{"websocket"}, HeaderConnection: []string{"upgrade"}}),
			check: assert.True,
		},
		{
			given: withHeaders(http.Header{HeaderUpgrade: []string{"Websocket"}, HeaderConnection: []string{"Upgrade"}}),
			check: assert.True,
		},
		{
			given: withHeaders(nil),
			check: assert.False,
		},
		{
			given: withHeaders(http.Header{HeaderUpgrade: []string{"other"}}),
			check: assert.False,
		},
		{
			given: withHeaders(http.Header{HeaderUpgrade: []string{"websocket"}, HeaderConnection: []string{"close"}}),
			check: assert.False,
		},
	}
	for i, tt := range tests {
		t.Run(fmt.Sprintf("test %d", i+1), func(t *testing.T) {
			tt.check(t, tt.given.IsWebSocket())
		})
	}
}
// TestContext_Bind checks that Bind decodes a JSON request body into the target struct.
func TestContext_Bind(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
	c := e.NewContext(req, nil)
	got := new(user)
	req.Header.Add(HeaderContentType, MIMEApplicationJSON)

	assert.NoError(t, c.Bind(got))
	assert.Equal(t, &user{ID: 1, Name: "Jon Snow"}, got)
}
// TestContext_RealIP checks client IP detection from X-Forwarded-For (first entry,
// with or without spaces/brackets), from X-Real-Ip, and the RemoteAddr fallback.
//
// Each case runs as a named subtest so a failure identifies the offending input
// (previously all cases ran in one loop without names).
func TestContext_RealIP(t *testing.T) {
	// withHeader builds a Context around a request carrying a single header value.
	withHeader := func(name, value string) *Context {
		return &Context{request: &http.Request{Header: http.Header{name: []string{value}}}}
	}

	tests := []struct {
		c      *Context
		name   string
		expect string
	}{
		{
			c:      withHeader(HeaderXForwardedFor, "127.0.0.1, 127.0.1.1, "),
			name:   "X-Forwarded-For IPv4 list with spaces and trailing comma",
			expect: "127.0.0.1",
		},
		{
			c:      withHeader(HeaderXForwardedFor, "127.0.0.1,127.0.1.1"),
			name:   "X-Forwarded-For IPv4 list without spaces",
			expect: "127.0.0.1",
		},
		{
			c:      withHeader(HeaderXForwardedFor, "127.0.0.1"),
			name:   "X-Forwarded-For single IPv4",
			expect: "127.0.0.1",
		},
		{
			c:      withHeader(HeaderXForwardedFor, "[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "),
			name:   "X-Forwarded-For bracketed IPv6 list with spaces and trailing comma",
			expect: "2001:db8:85a3:8d3:1319:8a2e:370:7348",
		},
		{
			c:      withHeader(HeaderXForwardedFor, "[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"),
			name:   "X-Forwarded-For bracketed IPv6 list without spaces",
			expect: "2001:db8:85a3:8d3:1319:8a2e:370:7348",
		},
		{
			c:      withHeader(HeaderXForwardedFor, "2001:db8:85a3:8d3:1319:8a2e:370:7348"),
			name:   "X-Forwarded-For single IPv6",
			expect: "2001:db8:85a3:8d3:1319:8a2e:370:7348",
		},
		{
			c:      withHeader("X-Real-Ip", "192.168.0.1"),
			name:   "X-Real-Ip IPv4",
			expect: "192.168.0.1",
		},
		{
			c:      withHeader("X-Real-Ip", "[2001:db8::1]"),
			name:   "X-Real-Ip bracketed IPv6",
			expect: "2001:db8::1",
		},
		{
			c:      &Context{request: &http.Request{RemoteAddr: "89.89.89.89:1654"}},
			name:   "RemoteAddr fallback",
			expect: "89.89.89.89",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.expect, tt.c.RealIP())
		})
	}
}
// TestContext_File checks serving files through Context.File from the default
// file system and from a custom Echo.Filesystem, and the error for a missing file.
func TestContext_File(t *testing.T) {
	var testCases = []struct {
		whenFS           fs.FS
		name             string
		whenFile         string
		expectError      string
		expectStartsWith []byte
		expectStatus     int
	}{
		{
			name:             "ok, from default file system",
			whenFile:         "_fixture/images/walle.png",
			whenFS:           nil,
			expectStatus:     http.StatusOK,
			expectStartsWith: []byte{0x89, 0x50, 0x4e},
		},
		{
			name:             "ok, from custom file system",
			whenFile:         "walle.png",
			whenFS:           os.DirFS("_fixture/images"),
			expectStatus:     http.StatusOK,
			expectStartsWith: []byte{0x89, 0x50, 0x4e},
		},
		{
			name:             "nok, not existent file",
			whenFile:         "not.png",
			whenFS:           os.DirFS("_fixture/images"),
			expectStatus:     http.StatusOK,
			expectStartsWith: nil,
			expectError:      "Not Found",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			if tc.whenFS != nil {
				e.Filesystem = tc.whenFS
			}
			req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			err := c.File(tc.whenFile)
			assert.Equal(t, tc.expectStatus, rec.Code)
			if tc.expectError == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tc.expectError)
			}

			// Compare only the leading bytes (the PNG signature prefix).
			prefix := rec.Body.Bytes()
			if n := len(tc.expectStartsWith); len(prefix) > n {
				prefix = prefix[:n]
			}
			assert.Equal(t, tc.expectStartsWith, prefix)
		})
	}
}
// TestContext_FileFS checks serving a file from an explicitly passed fs.FS and the
// error for a missing file.
func TestContext_FileFS(t *testing.T) {
	var testCases = []struct {
		whenFS           fs.FS
		name             string
		whenFile         string
		expectError      string
		expectStartsWith []byte
		expectStatus     int
	}{
		{
			name:             "ok",
			whenFile:         "walle.png",
			whenFS:           os.DirFS("_fixture/images"),
			expectStatus:     http.StatusOK,
			expectStartsWith: []byte{0x89, 0x50, 0x4e},
		},
		{
			name:             "nok, not existent file",
			whenFile:         "not.png",
			whenFS:           os.DirFS("_fixture/images"),
			expectStatus:     http.StatusOK,
			expectStartsWith: nil,
			expectError:      "Not Found",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			err := c.FileFS(tc.whenFile, tc.whenFS)
			assert.Equal(t, tc.expectStatus, rec.Code)
			if tc.expectError == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tc.expectError)
			}

			// Compare only the leading bytes (the PNG signature prefix).
			prefix := rec.Body.Bytes()
			if n := len(tc.expectStartsWith); len(prefix) > n {
				prefix = prefix[:n]
			}
			assert.Equal(t, tc.expectStartsWith, prefix)
		})
	}
}
// TestLogger checks that a context starts with the Echo logger, can be given its
// own logger, and falls back to the Echo logger after Reset.
func TestLogger(t *testing.T) {
	e := New()
	c := e.NewContext(nil, nil)

	initial := c.Logger()
	assert.NotNil(t, initial)
	assert.Equal(t, e.Logger, initial)

	custom := slog.New(slog.NewTextHandler(os.Stdout, nil))
	c.SetLogger(custom)
	assert.Equal(t, custom, c.Logger())

	// Resetting the context returns the initial Echo logger
	c.Reset(nil, nil)
	assert.Equal(t, e.Logger, c.Logger())
}
// TestRouteInfo checks that RouteInfo returns a copy: mutating the returned value
// (including its Parameters slice) does not affect the context. Changes made to
// the RouteInfo the context points to are visible on the next call.
func TestRouteInfo(t *testing.T) {
	e := New()
	c := e.NewContext(nil, nil)
	orgRI := RouteInfo{
		Name:       "root",
		Method:     http.MethodGet,
		Path:       "/*",
		Parameters: []string{"*"},
	}
	c.route = &orgRI
	ri := c.RouteInfo()
	assert.Equal(t, orgRI, ri)

	// Test mutability when middlewares start to change things:
	// RouteInfo inside context will not be affected when the returned instance is changed.
	expect := orgRI.Clone()
	ri.Path = "changed"
	ri.Parameters[0] = "changed"
	assert.Equal(t, expect, c.RouteInfo())

	// The context holds a pointer to orgRI, so changing orgRI itself IS reflected by
	// later RouteInfo() calls. (The previous comment here claimed the opposite.)
	expect = c.RouteInfo()
	orgRI.Name = "changed"
	assert.NotEqual(t, expect, c.RouteInfo())
	assert.Equal(t, "changed", c.RouteInfo().Name)
}
================================================
FILE: echo.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
/*
Package echo implements a high-performance, minimalist Go web framework.
Example:
package main
import (
"log/slog"
"net/http"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware"
)
// Handler
func hello(c *echo.Context) error {
return c.String(http.StatusOK, "Hello, World!")
}
func main() {
// Echo instance
e := echo.New()
// Middleware
e.Use(middleware.RequestLogger())
e.Use(middleware.Recover())
// Routes
e.GET("/", hello)
// Start server
if err := e.Start(":8080"); err != nil {
slog.Error("failed to start server", "error", err)
}
}
Learn more at https://echo.labstack.com
*/
package echo
import (
stdContext "context"
"encoding/json"
"errors"
"fmt"
"io/fs"
"log/slog"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"syscall"
)
// Echo is the top-level framework instance.
//
// Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these
// fields from handlers/middlewares and changing field values at the same time leads to data-races.
// Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action.
type Echo struct {
// serveHTTPFunc handles an incoming request.
// NOTE(review): presumably chosen at construction time (e.g. with/without pre-middleware); confirm where it is assigned.
serveHTTPFunc func(http.ResponseWriter, *http.Request)
// Binder binds request data to Go values.
// NOTE(review): presumably what Context.Bind delegates to; confirm.
Binder Binder
// Filesystem is the fs.FS that Context.File reads files from.
Filesystem fs.FS
// Renderer renders templates for Context.Render; Render returns an error when it is nil.
Renderer Renderer
// Validator is used by Context.Validate; Validate returns an error when it is nil.
Validator Validator
// JSONSerializer encodes and decodes JSON.
// NOTE(review): presumably used by the Context JSON helpers and JSON binding; confirm.
JSONSerializer JSONSerializer
// IPExtractor extracts the client IP address from a request.
// NOTE(review): presumably consulted by Context.RealIP when set; confirm.
IPExtractor IPExtractor
// OnAddRoute is a hook that receives a Route, judging by its name when routes are added.
// NOTE(review): confirm when it is called and whether a returned error aborts route registration.
OnAddRoute func(route Route) error
// HTTPErrorHandler is the centralized HTTP error handler.
// NOTE(review): presumably invoked with errors returned from handlers/middleware; confirm.
HTTPErrorHandler HTTPErrorHandler
// Logger is the instance logger. Context.Logger returns it until Context.SetLogger
// overrides it, and Context.Reset restores it.
Logger *slog.Logger
// contextPool pools reusable objects.
// NOTE(review): presumably *Context values recycled between requests; confirm.
contextPool sync.Pool
// router holds the registered routes.
// NOTE(review): presumably used to match incoming requests; confirm.
router Router
// premiddleware are middlewares that are called before routing is done
premiddleware []MiddlewareFunc
// middleware are middlewares that are called after routing is done and before handler is called
middleware []MiddlewareFunc
// contextPathParamAllocSize is accessed atomically and is not changed by Context.SetPathValues.
// NOTE(review): presumably the capacity used to pre-allocate path values for new Contexts; confirm.
contextPathParamAllocSize atomic.Int32
// formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm)
formParseMaxMemory int64
}
// Core function and interface types of the framework.
type (
	// JSONSerializer converts values to and from JSON. Serialize encodes target,
	// with indent controlling indentation, and Deserialize decodes into target.
	JSONSerializer interface {
		Serialize(c *Context, target any, indent string) error
		Deserialize(c *Context, target any) error
	}

	// HTTPErrorHandler is the centralized handler for HTTP errors. It receives the
	// request context together with the error to handle.
	HTTPErrorHandler func(c *Context, err error)

	// HandlerFunc is the signature of a function that serves an HTTP request.
	HandlerFunc func(c *Context) error

	// MiddlewareFunc takes the next handler in the chain and returns a handler wrapping it.
	MiddlewareFunc func(next HandlerFunc) HandlerFunc

	// MiddlewareConfigurator builds a middleware handler. It reports configuration
	// problems as an error instead of panicking.
	MiddlewareConfigurator interface {
		ToMiddleware() (MiddlewareFunc, error)
	}

	// Validator wraps the Validate method, which checks a value and returns an
	// error describing why it is invalid.
	Validator interface {
		Validate(i any) error
	}
)
// MIME types
//
// Media type values for the Content-Type header. Each "...CharsetUTF8" constant
// is its base type followed by "; " and charsetUTF8.
const (
	// MIMEApplicationJSON is the media type for JavaScript Object Notation (JSON).
	// See https://www.rfc-editor.org/rfc/rfc8259
	MIMEApplicationJSON = "application/json"
	// MIMEApplicationJSONCharsetUTF8 is MIMEApplicationJSON with an explicit UTF-8 charset parameter.
	//
	// Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default.
	// No "charset" parameter is defined for this registration.
	// Adding one really has no effect on compliant recipients.
	// See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1
	MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8
	MIMEApplicationJavaScript = "application/javascript"
	MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8
	MIMEApplicationXML = "application/xml"
	MIMEApplicationXMLCharsetUTF8 = MIMEApplicationXML + "; " + charsetUTF8
	MIMETextXML = "text/xml"
	MIMETextXMLCharsetUTF8 = MIMETextXML + "; " + charsetUTF8
	MIMEApplicationForm = "application/x-www-form-urlencoded"
	MIMEApplicationProtobuf = "application/protobuf"
	MIMEApplicationMsgpack = "application/msgpack"
	MIMETextHTML = "text/html"
	MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8
	MIMETextPlain = "text/plain"
	MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8
	MIMEMultipartForm = "multipart/form-data"
	MIMEOctetStream = "application/octet-stream"
)
const (
	// charsetUTF8 is the Content-Type parameter appended by the "...CharsetUTF8" MIME constants.
	charsetUTF8 = "charset=UTF-8"
	// PROPFIND Method can be used on collection and property resources.
	PROPFIND = "PROPFIND"
	// REPORT Method can be used to get information about a resource, see RFC 3253.
	REPORT = "REPORT"
	// RouteNotFound is a special method type for routes handling "route not found" (404) cases.
	RouteNotFound = "echo_route_not_found"
	// RouteAny is a special method type that matches any HTTP method in a request. Any has lower
	// priority than other methods that have been registered with Router to that path.
	RouteAny = "echo_route_any"
)
// Headers
//
// Canonical names of HTTP header fields used by Echo and its middleware.
const (
	HeaderAccept = "Accept"
	HeaderAcceptEncoding = "Accept-Encoding"
	// HeaderAllow is the name of the "Allow" header field used to list the set of methods
	// advertised as supported by the target resource. Returning an Allow header is mandatory
	// for status 405 (Method Not Allowed) and useful for the OPTIONS method in responses.
	// See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
	HeaderAllow = "Allow"
	HeaderAuthorization = "Authorization"
	HeaderContentDisposition = "Content-Disposition"
	HeaderContentEncoding = "Content-Encoding"
	HeaderContentLength = "Content-Length"
	HeaderContentType = "Content-Type"
	HeaderCookie = "Cookie"
	HeaderSetCookie = "Set-Cookie"
	HeaderIfModifiedSince = "If-Modified-Since"
	HeaderLastModified = "Last-Modified"
	HeaderLocation = "Location"
	HeaderRetryAfter = "Retry-After"
	HeaderUpgrade = "Upgrade"
	HeaderVary = "Vary"
	HeaderWWWAuthenticate = "WWW-Authenticate"
	HeaderXForwardedFor = "X-Forwarded-For"
	HeaderXForwardedProto = "X-Forwarded-Proto"
	HeaderXForwardedProtocol = "X-Forwarded-Protocol"
	HeaderXForwardedSsl = "X-Forwarded-Ssl"
	HeaderXUrlScheme = "X-Url-Scheme"
	HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
	HeaderXRealIP = "X-Real-Ip"
	HeaderXRequestID = "X-Request-Id"
	HeaderXCorrelationID = "X-Correlation-Id"
	HeaderXRequestedWith = "X-Requested-With"
	HeaderServer = "Server"
	// HeaderOrigin request header indicates the origin (scheme, hostname, and port) that caused the request.
	// See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
	HeaderOrigin = "Origin"
	HeaderCacheControl = "Cache-Control"
	HeaderConnection = "Connection"

	// Access control (CORS)
	HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
	HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"
	HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
	HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"
	HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"
	HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"
	HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"
	HeaderAccessControlMaxAge = "Access-Control-Max-Age"

	// Security
	HeaderStrictTransportSecurity = "Strict-Transport-Security"
	HeaderXContentTypeOptions = "X-Content-Type-Options"
	HeaderXXSSProtection = "X-XSS-Protection"
	HeaderXFrameOptions = "X-Frame-Options"
	HeaderContentSecurityPolicy = "Content-Security-Policy"
	HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only"
	HeaderXCSRFToken = "X-CSRF-Token" // #nosec G101
	HeaderReferrerPolicy = "Referrer-Policy"
	// HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's
	// origin and the origin of the requested resource.
	// See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site
	HeaderSecFetchSite = "Sec-Fetch-Site"
)
// Config is configuration for NewWithConfig function.
//
// Every field is optional: NewWithConfig starts from the defaults of New and only overrides the
// fields that are set (non-nil, or positive for FormParseMaxMemory).
type Config struct {
	// Logger is the slog logger instance used for application-wide structured logging.
	// If not set, the default created by New is kept: a slog JSONHandler writing to stdout.
	Logger *slog.Logger
	// HTTPErrorHandler is the centralized error handler that processes errors returned
	// by handlers and middleware, converting them to appropriate HTTP responses.
	// If not set, DefaultHTTPErrorHandler(false) is used.
	HTTPErrorHandler HTTPErrorHandler
	// Router is the HTTP request router responsible for matching URLs to handlers
	// using a radix tree-based algorithm.
	// If not set, NewRouter(RouterConfig{}) is used.
	Router Router
	// OnAddRoute is an optional callback hook executed when routes are registered.
	// Useful for route validation, logging, or custom route processing. A returned error
	// aborts the registration of that route.
	// If not set, no callback is executed.
	OnAddRoute func(route Route) error
	// Filesystem is the fs.FS implementation used for serving static files.
	// Supports os.DirFS, embed.FS, and custom implementations.
	// If not set, defaults to current working directory.
	Filesystem fs.FS
	// Binder handles automatic data binding from HTTP requests to Go structs.
	// Supports JSON, XML, form data, query parameters, and path parameters.
	// If not set, DefaultBinder is used.
	Binder Binder
	// Validator provides optional struct validation after data binding.
	// Commonly used with third-party validation libraries.
	// If not set, Context.Validate() returns ErrValidatorNotRegistered.
	Validator Validator
	// Renderer provides template rendering for generating HTML responses.
	// Requires integration with a template engine like html/template.
	// If not set, Context.Render() returns ErrRendererNotRegistered.
	Renderer Renderer
	// JSONSerializer handles JSON encoding and decoding for HTTP requests/responses.
	// Can be replaced with faster alternatives like jsoniter or sonic.
	// If not set, DefaultJSONSerializer using encoding/json is used.
	JSONSerializer JSONSerializer
	// IPExtractor defines the strategy for extracting the real client IP address
	// from requests, particularly important when behind proxies or load balancers.
	// Used for rate limiting, access control, and logging.
	// If not set, falls back to checking X-Forwarded-For and X-Real-IP headers.
	IPExtractor IPExtractor
	// FormParseMaxMemory is default value for memory limit that is used
	// when parsing multipart forms (See (*http.Request).ParseMultipartForm).
	// Values <= 0 are ignored by NewWithConfig and the default (defaultMemory) is kept.
	FormParseMaxMemory int64
}
// NewWithConfig creates an instance of Echo with given configuration.
//
// The instance is first built with New; afterwards every Config field that is set (non-nil
// values, or a positive FormParseMaxMemory) replaces the corresponding default. Unset fields
// keep the defaults chosen by New.
func NewWithConfig(config Config) *Echo {
	instance := New()

	// memory limit for multipart form parsing: only positive values are meaningful
	if config.FormParseMaxMemory > 0 {
		instance.formParseMaxMemory = config.FormParseMaxMemory
	}
	if config.IPExtractor != nil {
		instance.IPExtractor = config.IPExtractor
	}
	if config.JSONSerializer != nil {
		instance.JSONSerializer = config.JSONSerializer
	}
	if config.Renderer != nil {
		instance.Renderer = config.Renderer
	}
	if config.Validator != nil {
		instance.Validator = config.Validator
	}
	if config.Binder != nil {
		instance.Binder = config.Binder
	}
	if config.Filesystem != nil {
		instance.Filesystem = config.Filesystem
	}
	if config.OnAddRoute != nil {
		instance.OnAddRoute = config.OnAddRoute
	}
	if config.Router != nil {
		instance.router = config.Router
	}
	if config.HTTPErrorHandler != nil {
		instance.HTTPErrorHandler = config.HTTPErrorHandler
	}
	if config.Logger != nil {
		instance.Logger = config.Logger
	}
	return instance
}
// New creates an instance of Echo.
//
// Defaults: a slog JSON logger writing to stdout, a filesystem rooted at the current working
// directory, DefaultBinder, DefaultJSONSerializer, DefaultHTTPErrorHandler(false), a router
// created with NewRouter(RouterConfig{}) and defaultMemory as multipart form memory limit.
func New() *Echo {
	e := &Echo{
		Logger:             slog.New(slog.NewJSONHandler(os.Stdout, nil)),
		Filesystem:         newDefaultFS(),
		Binder:             &DefaultBinder{},
		JSONSerializer:     &DefaultJSONSerializer{},
		HTTPErrorHandler:   DefaultHTTPErrorHandler(false),
		router:             NewRouter(RouterConfig{}),
		formParseMaxMemory: defaultMemory,
	}
	// these two need the instance itself, so they can only be wired after construction
	e.serveHTTPFunc = e.serveHTTP
	e.contextPool.New = func() any { return newContext(nil, nil, e) }
	return e
}
// NewContext returns a new Context instance.
//
// The context is created directly (it is not taken from the context pool). Both request and
// response may be nil: contexts used by Echo.ServeHTTP are reset with c.Reset(req, resp) anyway.
// The arguments are useful when creating a context for tests and similar cases.
func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context {
	return newContext(r, w, e)
}

// Router returns the router used by this Echo instance (created by New or given in Config.Router).
func (e *Echo) Router() Router {
	return e.router
}
// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response
// with status code. `exposeError` parameter decides if returned message will contain also error message or not
//
// The status code comes from the first error in err's chain implementing HTTPStatusCoder (found with
// errors.As); a missing status coder or a zero status code results in 500 Internal Server Error.
// The JSON body depends on the type of that status coder:
//   - json.Marshaler: the value is sent as is,
//   - *HTTPError: {"message": Message (or the status text when empty)}, plus "error" with the wrapped
//     error's text when exposeError is true and a wrapped error exists,
//   - anything else, including no status coder at all: {"message": status text}, plus "error" with
//     err.Error() when exposeError is true.
//
// For HEAD requests only the status code is sent, without a body.
//
// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately)
// Note: when the handler has already sent a response (e.g. with c.JSON()) and an error is then returned by
// middleware further up the call chain, the global error handler ignores that error: the response is already
// "committed" and the status code header has been sent to the client.
func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
	return func(c *Context, err error) {
		// nothing can be sent anymore once headers/status have been written
		if r, _ := UnwrapResponse(c.response); r != nil && r.Committed {
			return
		}
		code := http.StatusInternalServerError
		var sc HTTPStatusCoder
		if errors.As(err, &sc) {
			if tmp := sc.StatusCode(); tmp != 0 { // a zero status code keeps the 500 default
				code = tmp
			}
		}
		var result any
		// note: the switch inspects sc (the status coder found in the chain), not err itself
		switch m := sc.(type) {
		case json.Marshaler: // this type knows how to format itself to JSON
			result = m
		case *HTTPError:
			sText := m.Message
			if sText == "" {
				sText = http.StatusText(code)
			}
			msg := map[string]any{"message": sText}
			if exposeError {
				if wrappedErr := m.Unwrap(); wrappedErr != nil {
					msg["error"] = wrappedErr.Error()
				}
			}
			result = msg
		default:
			msg := map[string]any{"message": http.StatusText(code)}
			if exposeError {
				msg["error"] = err.Error()
			}
			result = msg
		}
		var cErr error
		if c.Request().Method == http.MethodHead { // Issue #608
			cErr = c.NoContent(code)
		} else {
			cErr = c.JSON(code, result)
		}
		if cErr != nil {
			c.Logger().Error("echo default error handler failed to send error to client", "error", cErr) // truly rare case. ala client already disconnected
		}
	}
}
// Pre adds middleware to the chain which is run before router tries to find matching route.
// Meaning middleware is executed even for 404 (not found) cases.
// Pre-middleware run in the order they were added: the first one added is the outermost.
func (e *Echo) Pre(middleware ...MiddlewareFunc) {
	e.premiddleware = append(e.premiddleware, middleware...)
}

// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed.
// Middleware run in the order they were added: the first one added is the outermost.
func (e *Echo) Use(middleware ...MiddlewareFunc) {
	e.middleware = append(e.middleware, middleware...)
}
// CONNECT registers a new CONNECT route for a path with matching handler in the
// router with optional route-level middleware. Panics on error.
func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodConnect, path, h, m...)
}

// DELETE registers a new DELETE route for a path with matching handler in the router
// with optional route-level middleware. Panics on error.
func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodDelete, path, h, m...)
}

// GET registers a new GET route for a path with matching handler in the router
// with optional route-level middleware. Panics on error.
func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodGet, path, h, m...)
}

// HEAD registers a new HEAD route for a path with matching handler in the
// router with optional route-level middleware. Panics on error.
func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodHead, path, h, m...)
}

// OPTIONS registers a new OPTIONS route for a path with matching handler in the
// router with optional route-level middleware. Panics on error.
func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodOptions, path, h, m...)
}

// PATCH registers a new PATCH route for a path with matching handler in the
// router with optional route-level middleware. Panics on error.
func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodPatch, path, h, m...)
}

// POST registers a new POST route for a path with matching handler in the
// router with optional route-level middleware. Panics on error.
func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodPost, path, h, m...)
}

// PUT registers a new PUT route for a path with matching handler in the
// router with optional route-level middleware. Panics on error.
func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodPut, path, h, m...)
}

// TRACE registers a new TRACE route for a path with matching handler in the
// router with optional route-level middleware. Panics on error.
func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodTrace, path, h, m...)
}

// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases)
// for current request URL.
// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with
// wildcard/match-any character (`/*`, `/download/*` etc). Panics on error.
//
// Example: `e.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })`
func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return e.Add(RouteNotFound, path, h, m...)
}
// Any registers a new route for path that matches requests of any HTTP method, with matching handler
// and optional route-level middleware. Panics on error.
//
// The route is registered with the special RouteAny method type. Routes registered for a specific
// HTTP method on the same path take priority over it.
func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
	return e.Add(RouteAny, path, handler, middleware...)
}
// Match registers a new route for each of the given HTTP methods and path with matching
// handler in the router with optional route-level middleware. Panics on error.
//
// Registration is attempted for every method, even after a failure. If any registration fails,
// Match panics with a single error joining all registration errors (see errors.Join). Routes of
// methods that were registered successfully stay registered. Use AddRoute for panic-free usage.
func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
	var errs []error
	ris := make(Routes, 0, len(methods))
	for _, m := range methods {
		ri, err := e.AddRoute(Route{
			Method:      m,
			Path:        path,
			Handler:     handler,
			Middlewares: middleware,
		})
		if err != nil {
			errs = append(errs, err)
			continue
		}
		ris = append(ris, ri)
	}
	if len(errs) > 0 {
		// Panic with an `error` value (as Add does) rather than a bare []error, so recover() code
		// that expects an error sees every failure.
		panic(errors.Join(errs...)) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
	}
	return ris
}
// Static registers a new route with path prefix to serve static files from the provided root directory.
//
// fsRoot is resolved inside e.Filesystem with MustSubFS, so an invalid root makes Static panic.
// The route is registered as `GET pathPrefix*`, exactly like StaticFS does.
func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo {
	return e.StaticFS(pathPrefix, MustSubFS(e.Filesystem, fsRoot), middleware...)
}

// StaticFS registers a new route with path prefix to serve static files from the provided file system.
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
// including `assets/images` as their prefix.
func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo {
	// path unescaping stays enabled: the handler unescapes the `*` parameter itself
	handler := StaticDirectoryHandler(filesystem, false)
	return e.Add(http.MethodGet, pathPrefix+"*", handler, middleware...)
}
// StaticDirectoryHandler creates handler function to serve files from provided file system
// When disablePathUnescaping is set then file name from path is not unescaped and is served as is.
//
// The file name is read from the `*` path parameter (Static and StaticFS register the handler on
// `pathPrefix*` routes). The handler:
//   - returns ErrNotFound when the cleaned name does not exist in fileSystem,
//   - redirects (301) to the request path plus a trailing `/` when the name is a directory and the
//     request path does not already end with `/`,
//   - otherwise serves the name with fsFile.
func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc {
	return func(c *Context) error {
		p := c.Param("*")
		if !disablePathUnescaping { // when router is already unescaping we do not want to do it twice
			tmpPath, err := url.PathUnescape(p)
			if err != nil {
				return fmt.Errorf("failed to unescape path variable: %w", err)
			}
			p = tmpPath
		}

		// fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid
		name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/")))
		fi, err := fs.Stat(fileSystem, name)
		if err != nil {
			return ErrNotFound
		}

		// If the request is for a directory and does not end with "/"
		p = c.Request().URL.Path // path must not be empty.
		if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' {
			// Redirect to ends with "/". sanitizeURI keeps the Location from being interpreted by
			// browsers as an absolute URL pointing to another host (open redirect).
			return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/"))
		}
		return fsFile(c, name, fileSystem)
	}
}
// FileFS registers a new GET route for path that serves file from the provided file system.
// Panics on error.
func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo {
	handler := StaticFileHandler(file, filesystem)
	return e.GET(path, handler, m...)
}

// StaticFileHandler creates handler function that always serves the same file from the provided
// file system, regardless of the request path.
func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc {
	serve := func(c *Context) error {
		return fsFile(c, file, filesystem)
	}
	return serve
}

// File registers a new GET route for path that serves a static file (opened through
// Context.File) with optional route-level middleware. Panics on error.
func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
	return e.Add(http.MethodGet, path, func(c *Context) error {
		return c.File(file)
	}, middleware...)
}
// AddRoute registers a new Route with default host Router.
// Unlike Add and the HTTP-verb helpers, it reports failures as an error instead of panicking.
func (e *Echo) AddRoute(route Route) (RouteInfo, error) {
	return e.add(route)
}

// add runs the OnAddRoute hook (when set), registers route with the router and records the
// largest path-parameter count seen so far in contextPathParamAllocSize. On any error an empty
// RouteInfo is returned and nothing is recorded.
func (e *Echo) add(route Route) (RouteInfo, error) {
	if hook := e.OnAddRoute; hook != nil {
		if err := hook(route); err != nil {
			return RouteInfo{}, err
		}
	}

	info, err := e.router.Add(route)
	if err != nil {
		return RouteInfo{}, err
	}

	// presumably used to size path-parameter storage of contexts - TODO confirm at its readers
	if count := int32(len(info.Parameters)); count > e.contextPathParamAllocSize.Load() { // #nosec G115
		e.contextPathParamAllocSize.Store(count)
	}
	return info, nil
}
// Add registers a new route for an HTTP method and path with matching handler
// in the router with optional route-level middleware. Panics on error; use AddRoute
// to receive the error instead.
func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
	route := Route{
		Method:      method,
		Path:        path,
		Handler:     handler,
		Middlewares: middleware,
	}
	info, err := e.add(route)
	if err != nil {
		panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
	}
	return info
}
// Group creates a new router group with prefix and optional group-level middleware.
// The group registers its routes through this Echo instance.
func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) {
	g = &Group{echo: e, prefix: prefix}
	g.Use(m...)
	return g
}
// PreMiddlewares returns registered pre middlewares. These are middleware to the chain
// which are run before router tries to find matching route.
// Use this method to build your own ServeHTTP method.
//
// NOTE: returned slice is not a copy. Do not mutate.
func (e *Echo) PreMiddlewares() []MiddlewareFunc {
	return e.premiddleware
}

// Middlewares returns registered route level middlewares. Does not contain any group level
// middlewares. Use this method to build your own ServeHTTP method.
//
// NOTE: returned slice is not a copy. Do not mutate.
func (e *Echo) Middlewares() []MiddlewareFunc {
	return e.middleware
}

// AcquireContext returns a `Context` instance from the pool.
// You must return the context by calling `ReleaseContext()`.
//
// The context is not reset here and may still hold state from a previous request: call
// c.Reset(request, response) before using it (Echo.ServeHTTP does the same).
func (e *Echo) AcquireContext() *Context {
	return e.contextPool.Get().(*Context)
}

// ReleaseContext returns the `Context` instance back to the pool.
// You must call it after `AcquireContext()`. The context must not be used after it was released.
func (e *Echo) ReleaseContext(c *Context) {
	e.contextPool.Put(c)
}
// ServeHTTP implements `http.Handler` interface, which serves HTTP requests.
// It delegates to serveHTTPFunc (set to the built-in serveHTTP by New).
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	e.serveHTTPFunc(w, r)
}

// serveHTTP is the built-in request serving implementation.
//
// It borrows a Context from the pool (returned when done), resets it for this request, then runs
// pre-middleware, routing, route-level middleware and the matched handler. An error returned
// by that chain is handed to HTTPErrorHandler.
func (e *Echo) serveHTTP(w http.ResponseWriter, r *http.Request) {
	c := e.contextPool.Get().(*Context)
	defer e.contextPool.Put(c)
	c.Reset(r, w)

	var chain HandlerFunc
	if len(e.premiddleware) == 0 {
		// nothing must run before routing, so the route is resolved right away
		chain = applyMiddleware(e.router.Route(c), e.middleware...)
	} else {
		// routing is postponed into the innermost handler so that pre-middleware run first,
		// even when no route matches
		routeThenServe := func(cc *Context) error {
			return applyMiddleware(e.router.Route(cc), e.middleware...)(cc)
		}
		chain = applyMiddleware(routeThenServe, e.premiddleware...)
	}

	if err := chain(c); err != nil {
		e.HTTPErrorHandler(c, err)
	}
}
// Start starts an HTTP server on the given address with Echo as the handler serving requests. The server
// is shut down when the process receives os.Interrupt (`ctrl+c`) or syscall.SIGTERM. Method returns only
// errors that are not http.ErrServerClosed.
//
// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration
// options.
//
// In need of customization use:
//
//	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
//	defer cancel()
//	sc := echo.StartConfig{Address: ":8080"}
//	if err := sc.Start(ctx, e); err != nil && !errors.Is(err, http.ErrServerClosed) {
//		slog.Error(err.Error())
//	}
//
//	// or standard library `http.Server`
//
//	s := http.Server{Addr: ":8080", Handler: e}
//	if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
//		slog.Error(err.Error())
//	}
func (e *Echo) Start(address string) error {
	// the context is cancelled (starting the shutdown process) on ctrl+c or SIGTERM
	ctx, stop := signal.NotifyContext(stdContext.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	sc := StartConfig{Address: address}
	return sc.Start(ctx, e)
}
// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`.
// The matched route path and path parameters are exposed to h through http.Request.Pattern and
// http.Request.PathValue. The wrapped handler always returns nil.
func WrapHandler(h http.Handler) HandlerFunc {
	return func(c *Context) error {
		req := populateStdlibRouteInfo(c)
		h.ServeHTTP(c.Response(), req)
		return nil
	}
}

// WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc`.
//
// When the wrapped middleware calls its next handler, the (possibly replaced) request and response
// writer are installed into the Context before the echo chain continues. The error of the echo chain
// is returned; it is nil when the wrapped middleware never calls next.
func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc {
	return func(next HandlerFunc) HandlerFunc {
		return func(c *Context) error {
			req := populateStdlibRouteInfo(c)

			var nextErr error
			continueChain := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				c.SetRequest(r)
				c.SetResponse(NewResponse(w, c.echo.Logger))
				nextErr = next(c)
			})
			m(continueChain).ServeHTTP(c.Response(), req)
			return nextErr
		}
	}
}

// populateStdlibRouteInfo copies the matched route path (as http.Request.Pattern) and the path
// parameters (via http.Request.SetPathValue) onto the context's current request and returns it.
func populateStdlibRouteInfo(c *Context) *http.Request {
	req := c.Request()
	req.Pattern = c.Path()
	for _, pv := range c.PathValues() {
		req.SetPathValue(pv.Name, pv.Value)
	}
	return req
}
// applyMiddleware wraps h with the given middleware so that middleware[0] becomes the outermost
// layer (runs first) and h the innermost. With no middleware, h is returned unchanged.
func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc {
	wrapped := h
	for n := len(middleware); n > 0; n-- {
		wrapped = middleware[n-1](wrapped)
	}
	return wrapped
}
// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open`
// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images`
// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break
// all old applications that rely on being able to traverse up from current executable run path.
// NB: private because you really should use fs.FS implementation instances
type defaultFS struct {
	// fs is the os.DirFS rooted at prefix that Open delegates to.
	fs fs.FS
	// prefix is the OS directory the filesystem is rooted at (used by subFS to resolve relative roots).
	prefix string
}

// newDefaultFS returns a defaultFS rooted at the current working directory.
// NOTE(review): the os.Getwd error is ignored, leaving an empty prefix if it fails.
func newDefaultFS() *defaultFS {
	cwd, _ := os.Getwd()
	return &defaultFS{fs: os.DirFS(cwd), prefix: cwd}
}

// Open implements fs.FS by delegating to the wrapped os.DirFS. The receiver is deliberately not
// named `fs` so that it does not shadow the io/fs package.
func (d defaultFS) Open(name string) (fs.File, error) {
	return d.fs.Open(name)
}
// subFS returns a filesystem rooted at root inside currentFs.
//
// A *defaultFS gets special treatment to keep the pre-fs.FS (os.Open-like) behaviour: a relative
// root is joined to the current prefix, an absolute root is used as is, and a new defaultFS backed
// by os.DirFS is returned. Any other fs.FS is handled by fs.Sub, which rejects roots that are not
// valid fs paths (such as "../x" or "/x").
func subFS(currentFs fs.FS, root string) (fs.FS, error) {
	// fs.FS operates only with forward slashes; ToSlash is necessary on Windows
	cleaned := filepath.ToSlash(filepath.Clean(root))

	dFS, isDefault := currentFs.(*defaultFS)
	if !isDefault {
		return fs.Sub(currentFs, cleaned)
	}

	// fs.FS.Open rejects relative ("./", "../") and absolute paths, but before echo.Filesystem existed
	// paths like `./myfile.log` or `/etc/hosts` worked with os.Open - keep supporting them here
	if !filepath.IsAbs(cleaned) {
		cleaned = filepath.Join(dFS.prefix, cleaned)
	}
	return &defaultFS{prefix: cleaned, fs: os.DirFS(cleaned)}, nil
}
// MustSubFS creates sub FS from current filesystem or panic on failure.
// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
//
// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with
// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to
// create sub fs which uses necessary prefix for directory path.
func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS {
	sub, err := subFS(currentFs, fsRoot)
	if err == nil {
		return sub
	}
	panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err))
}
// sanitizeURI neutralizes redirect targets that browsers would treat as scheme-relative URLs
// (i.e. pointing to another host), which would turn a redirect to them into an open redirect.
//
// Browsers treat `\` like `/`. Per the WHATWG URL standard they also remove ASCII tab, LF and CR
// anywhere in a URL and strip leading C0 control characters and spaces. So not only `//host`,
// `\\host` and `\/host`, but also `/<TAB>/host` or `//<TAB>/host` lead to another host.
// When the leading run of slashes (ignoring those characters) contains more than one slash, the
// whole run is replaced with a single `/`. Everything else is returned unchanged.
func sanitizeURI(uri string) string {
	slashes := 0
	i := 0
scan:
	for ; i < len(uri); i++ {
		switch c := uri[i]; {
		case c == '/' || c == '\\':
			slashes++
		case c == '\t' || c == '\n' || c == '\r':
			// dropped by browsers anywhere in a URL, so they do not end a run of slashes
		case slashes == 0 && c <= ' ':
			// leading C0 control characters and spaces are stripped by browsers
		default:
			break scan
		}
	}
	if slashes < 2 {
		return uri
	}
	return "/" + uri[i:]
}
================================================
FILE: echo_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"bytes"
stdContext "context"
"errors"
"fmt"
"io/fs"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// user is the fixture type used by binding/rendering tests; it carries tags for every binding source.
type user struct {
	ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"`
	Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"`
}

// Serialized forms of user (and deliberately broken variants) used as request/response fixtures.
const (
	userJSON = `{"id":1,"name":"Jon Snow"}`
	usersJSON = `[{"id":1,"name":"Jon Snow"}]`
	userXML = `1 Jon Snow `
	userForm = `id=1&name=Jon Snow`
	invalidContent = "invalid content"
	userJSONInvalidType = `{"id":"1","name":"Jon Snow"}`
	userXMLConvertNumberError = `Number one Jon Snow `
	userXMLUnsupportedTypeError = `<>Number one>Jon Snow `
)

// userJSONPretty is the indented JSON form of user.
const userJSONPretty = `{
"id": 1,
"name": "Jon Snow"
}`

// userXMLPretty is the indented XML form of user.
const userXMLPretty = `
1
Jon Snow
`

// dummyQuery is an arbitrary query string added to requests where query values must be present.
var dummyQuery = url.Values{"dummy": []string{"useless"}}
// TestEcho checks that New provides a router and that the default error handler turns a plain
// error into a 500 response.
func TestEcho(t *testing.T) {
	e := New()
	assert.NotNil(t, e.Router())

	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	e.HTTPErrorHandler(c, errors.New("error"))
	assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
// TestNewWithConfig checks that an instance built from an empty Config serves registered routes.
func TestNewWithConfig(t *testing.T) {
	e := NewWithConfig(Config{})
	e.GET("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "Hello, World!")
	})

	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	assert.Equal(t, http.StatusTeapot, rec.Code)
	assert.Equal(t, `Hello, World!`, rec.Body.String())
}
// TestEcho_StaticFS exercises StaticFS: serving files, 404s, directory redirects, directory
// traversal protection and open-redirect protection of the trailing-slash redirect.
func TestEcho_StaticFS(t *testing.T) {
	var testCases = []struct {
		givenFs fs.FS
		name string
		givenPrefix string
		givenFsRoot string
		whenURL string
		expectHeaderLocation string
		expectBodyStartsWith string
		expectStatus int
	}{
		{
			name: "ok",
			givenPrefix: "/images",
			givenFs: os.DirFS("./_fixture/images"),
			whenURL: "/images/walle.png",
			expectStatus: http.StatusOK,
			expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
		},
		{
			name: "ok, from sub fs",
			givenPrefix: "/images",
			givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"),
			whenURL: "/images/walle.png",
			expectStatus: http.StatusOK,
			expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
		},
		{
			name: "No file",
			givenPrefix: "/images",
			givenFs: os.DirFS("_fixture/scripts"),
			whenURL: "/images/bolt.png",
			expectStatus: http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name: "Directory",
			givenPrefix: "/images",
			givenFs: os.DirFS("_fixture/images"),
			whenURL: "/images/",
			expectStatus: http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name: "Directory Redirect",
			givenPrefix: "/",
			givenFs: os.DirFS("_fixture/"),
			whenURL: "/folder",
			expectStatus: http.StatusMovedPermanently,
			expectHeaderLocation: "/folder/",
			expectBodyStartsWith: "",
		},
		{
			name: "Directory Redirect with non-root path",
			givenPrefix: "/static",
			givenFs: os.DirFS("_fixture"),
			whenURL: "/static",
			expectStatus: http.StatusMovedPermanently,
			expectHeaderLocation: "/static/",
			expectBodyStartsWith: "",
		},
		{
			name: "Prefixed directory 404 (request URL without slash)",
			givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
			givenFs: os.DirFS("_fixture"),
			whenURL: "/folder", // no trailing slash
			expectStatus: http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name: "Prefixed directory redirect (without slash redirect to slash)",
			givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
			givenFs: os.DirFS("_fixture"),
			whenURL: "/folder", // no trailing slash
			expectStatus: http.StatusMovedPermanently,
			expectHeaderLocation: "/folder/",
			expectBodyStartsWith: "",
		},
		// NOTE(review): the following 200 OK cases expect an empty body (an empty expectBodyStartsWith
		// means "body must be empty" below) - confirm against the served index.html fixtures.
		{
			name: "Directory with index.html",
			givenPrefix: "/",
			givenFs: os.DirFS("_fixture"),
			whenURL: "/",
			expectStatus: http.StatusOK,
			expectBodyStartsWith: "",
		},
		{
			name: "Prefixed directory with index.html (prefix ending with slash)",
			givenPrefix: "/assets/",
			givenFs: os.DirFS("_fixture"),
			whenURL: "/assets/",
			expectStatus: http.StatusOK,
			expectBodyStartsWith: "",
		},
		{
			name: "Prefixed directory with index.html (prefix ending without slash)",
			givenPrefix: "/assets",
			givenFs: os.DirFS("_fixture"),
			whenURL: "/assets/",
			expectStatus: http.StatusOK,
			expectBodyStartsWith: "",
		},
		{
			name: "Sub-directory with index.html",
			givenPrefix: "/",
			givenFs: os.DirFS("_fixture"),
			whenURL: "/folder/",
			expectStatus: http.StatusOK,
			expectBodyStartsWith: "",
		},
		{
			name: "do not allow directory traversal (backslash - windows separator)",
			givenPrefix: "/",
			givenFs: os.DirFS("_fixture/"),
			whenURL: `/..\\middleware/basic_auth.go`,
			expectStatus: http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name: "do not allow directory traversal (slash - unix separator)",
			givenPrefix: "/",
			givenFs: os.DirFS("_fixture/"),
			whenURL: `/../middleware/basic_auth.go`,
			expectStatus: http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name: "open redirect vulnerability",
			givenPrefix: "/",
			givenFs: os.DirFS("_fixture/"),
			whenURL: "/open.redirect.hackercom%2f..",
			expectStatus: http.StatusMovedPermanently,
			expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad
			expectBodyStartsWith: "",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			tmpFs := tc.givenFs
			if tc.givenFsRoot != "" {
				tmpFs = MustSubFS(tmpFs, tc.givenFsRoot)
			}
			e.StaticFS(tc.givenPrefix, tmpFs)
			req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)
			assert.Equal(t, tc.expectStatus, rec.Code)
			body := rec.Body.String()
			if tc.expectBodyStartsWith != "" {
				assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith), "unexpected body: %q", body)
			} else {
				assert.Equal(t, "", body)
			}

			// look the header up once: indexing a missing header would panic instead of failing the test
			location, hasLocation := rec.Result().Header["Location"]
			if tc.expectHeaderLocation != "" {
				if assert.True(t, hasLocation, "expected Location header") {
					assert.Equal(t, tc.expectHeaderLocation, location[0])
				}
			} else {
				assert.False(t, hasLocation)
			}
		})
	}
}
// TestEcho_FileFS checks that FileFS serves exactly one file on exactly one path and answers 404
// for other paths or a missing file.
func TestEcho_FileFS(t *testing.T) {
	var testCases = []struct {
		whenFS fs.FS
		name string
		whenPath string
		whenFile string
		givenURL string
		expectStartsWith []byte
		expectCode int
	}{
		{
			name: "ok",
			whenPath: "/walle",
			whenFS: os.DirFS("_fixture/images"),
			whenFile: "walle.png",
			givenURL: "/walle",
			expectCode: http.StatusOK,
			expectStartsWith: []byte{0x89, 0x50, 0x4e},
		},
		{
			name: "nok, requesting invalid path",
			whenPath: "/walle",
			whenFS: os.DirFS("_fixture/images"),
			whenFile: "walle.png",
			givenURL: "/walle.png",
			expectCode: http.StatusNotFound,
			expectStartsWith: []byte(`{"message":"Not Found"}`),
		},
		{
			name: "nok, serving not existent file from filesystem",
			whenPath: "/walle",
			whenFS: os.DirFS("_fixture/images"),
			whenFile: "not-existent.png",
			givenURL: "/walle",
			expectCode: http.StatusNotFound,
			expectStartsWith: []byte(`{"message":"Not Found"}`),
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)

			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, tc.givenURL, nil))
			assert.Equal(t, tc.expectCode, rec.Code)

			// compare only the leading bytes; a shorter body is compared as a whole (and fails)
			got := rec.Body.Bytes()
			got = got[:min(len(got), len(tc.expectStartsWith))]
			assert.Equal(t, tc.expectStartsWith, got)
		})
	}
}
// TestEcho_StaticPanic ensures Static panics when given a root that would escape
// the configured filesystem (a parent-relative or an absolute root).
func TestEcho_StaticPanic(t *testing.T) {
	for _, tc := range []struct {
		name      string
		givenRoot string
	}{
		{name: "panics for ../", givenRoot: "../assets"},
		{name: "panics for /", givenRoot: "/assets"},
	} {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			e.Filesystem = os.DirFS("./")

			register := func() { e.Static("../assets", tc.givenRoot) }
			assert.Panics(t, register)
		})
	}
}
// TestEchoStaticRedirectIndex registers a Static route, starts a real server on a
// random port and checks that requesting the bare prefix ("/static") serves the
// directory's index page with 200.
func TestEchoStaticRedirectIndex(t *testing.T) {
	e := New()
	ri := e.Static("/static", "_fixture")
	assert.Equal(t, http.MethodGet, ri.Method)
	assert.Equal(t, "/static*", ri.Path)
	assert.Equal(t, "GET:/static*", ri.Name)
	assert.Equal(t, []string{"*"}, ri.Parameters)

	ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
	defer cancel()
	addr, err := startOnRandomPort(ctx, e)
	if err != nil {
		// Without a listening server the request below cannot say anything useful.
		t.Fatal(err)
	}

	code, body, err := doGet(fmt.Sprintf("http://%v/static", addr))
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, code)
	// NOTE: assumes _fixture/index.html begins with the HTML5 doctype. The previous
	// check (HasPrefix(body, "")) was always true and verified nothing.
	assert.True(t, strings.HasPrefix(body, "<!doctype html>"), "unexpected body: %q", body)
}
// TestEchoFile checks that File serves a file from echo's default filesystem
// (paths relative to the working directory) and returns a JSON 404 for a file
// that does not exist.
func TestEchoFile(t *testing.T) {
	type testCase struct {
		name             string
		givenPath        string
		givenFile        string
		whenPath         string
		expectStartsWith string
		expectCode       int
	}
	cases := []testCase{
		{
			name:             "ok",
			givenPath:        "/walle",
			givenFile:        "_fixture/images/walle.png",
			whenPath:         "/walle",
			expectCode:       http.StatusOK,
			expectStartsWith: string([]byte{0x89, 0x50, 0x4e}),
		},
		{
			name:             "ok with relative path",
			givenPath:        "/",
			givenFile:        "./go.mod",
			whenPath:         "/",
			expectCode:       http.StatusOK,
			expectStartsWith: "module github.com/labstack/echo/v",
		},
		{
			name:             "nok file does not exist",
			givenPath:        "/",
			givenFile:        "./this-file-does-not-exist",
			whenPath:         "/",
			expectCode:       http.StatusNotFound,
			expectStartsWith: "{\"message\":\"Not Found\"}\n",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			e := New() // relies on echo's default filesystem
			e.File(tc.givenPath, tc.givenFile)

			status, body := request(http.MethodGet, tc.whenPath, e)
			assert.Equal(t, tc.expectCode, status)
			// Only the leading part of the body is compared.
			body = body[:min(len(body), len(tc.expectStartsWith))]
			assert.Equal(t, tc.expectStartsWith, body)
		})
	}
}
// TestEchoMiddleware checks that Pre middleware runs before routing and that Use
// middlewares run afterwards, in registration order.
func TestEchoMiddleware(t *testing.T) {
	e := New()
	var trace bytes.Buffer

	e.Pre(func(next HandlerFunc) HandlerFunc {
		return func(c *Context) error {
			// Pre middleware runs before a route is matched, so RouteInfo is still empty.
			assert.Equal(t, RouteInfo{}, c.RouteInfo())
			trace.WriteString("-1")
			return next(c)
		}
	})

	// mark returns a middleware that appends label to trace before continuing.
	mark := func(label string) MiddlewareFunc {
		return func(next HandlerFunc) HandlerFunc {
			return func(c *Context) error {
				trace.WriteString(label)
				return next(c)
			}
		}
	}
	for _, label := range []string{"1", "2", "3"} {
		e.Use(mark(label))
	}

	e.GET("/", func(c *Context) error {
		return c.String(http.StatusOK, "OK")
	})

	status, body := request(http.MethodGet, "/", e)
	assert.Equal(t, "-1123", trace.String())
	assert.Equal(t, http.StatusOK, status)
	assert.Equal(t, "OK", body)
}
// TestEchoMiddlewareError checks that an error returned by middleware, before the
// handler is reached, results in a 500 response.
func TestEchoMiddlewareError(t *testing.T) {
	e := New()
	failing := func(HandlerFunc) HandlerFunc {
		return func(*Context) error {
			return errors.New("error")
		}
	}
	e.Use(failing)
	e.GET("/", notFoundHandler)

	status, _ := request(http.MethodGet, "/", e)
	assert.Equal(t, http.StatusInternalServerError, status)
}
func TestEchoHandler(t *testing.T) {
e := New()
// HandlerFunc
e.GET("/ok", func(c *Context) error {
return c.String(http.StatusOK, "OK")
})
c, b := request(http.MethodGet, "/ok", e)
assert.Equal(t, http.StatusOK, c)
assert.Equal(t, "OK", b)
}
// TestEchoWrapHandler checks that a net/http handler wrapped with WrapHandler
// writes the response and sees echo's path parameters through r.PathValue and the
// route path through r.Pattern.
func TestEchoWrapHandler(t *testing.T) {
	var gotID, gotPattern string
	stdHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("test"))
		gotID = r.PathValue("id")
		gotPattern = r.Pattern
	})

	e := New()
	e.GET("/:id", WrapHandler(stdHandler))

	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/123", nil))

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "test", rec.Body.String())
	assert.Equal(t, "123", gotID)
	assert.Equal(t, "/:id", gotPattern)
}
// TestEchoWrapMiddleware checks that net/http middleware wrapped with
// WrapMiddleware can read echo's path parameters and route pattern, and still
// passes the request on to the echo handler.
func TestEchoWrapMiddleware(t *testing.T) {
	var gotID, gotPattern string
	capture := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			gotID = r.PathValue("id")
			gotPattern = r.Pattern
			next.ServeHTTP(w, r)
		})
	}

	e := New()
	e.Use(WrapMiddleware(capture))
	e.GET("/:id", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/123", nil))

	assert.Equal(t, http.StatusTeapot, rec.Code)
	assert.Equal(t, "OK", rec.Body.String())
	assert.Equal(t, "123", gotID)
	assert.Equal(t, "/:id", gotPattern)
}
// TestEchoConnect checks the RouteInfo returned by CONNECT and that a CONNECT
// request reaches the registered handler.
func TestEchoConnect(t *testing.T) {
	const method = http.MethodConnect
	e := New()
	ri := e.CONNECT("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	assert.Equal(t, method, ri.Method)
	assert.Equal(t, "/", ri.Path)
	assert.Equal(t, method+":/", ri.Name)
	assert.Nil(t, ri.Parameters)

	gotStatus, gotBody := request(method, "/", e)
	assert.Equal(t, http.StatusTeapot, gotStatus)
	assert.Equal(t, "OK", gotBody)
}
// TestEchoDelete checks the RouteInfo returned by DELETE and that a DELETE
// request reaches the registered handler.
func TestEchoDelete(t *testing.T) {
	const method = http.MethodDelete
	e := New()
	ri := e.DELETE("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	assert.Equal(t, method, ri.Method)
	assert.Equal(t, "/", ri.Path)
	assert.Equal(t, method+":/", ri.Name)
	assert.Nil(t, ri.Parameters)

	gotStatus, gotBody := request(method, "/", e)
	assert.Equal(t, http.StatusTeapot, gotStatus)
	assert.Equal(t, "OK", gotBody)
}
// TestEchoGet checks the RouteInfo returned by GET and that a GET request reaches
// the registered handler.
func TestEchoGet(t *testing.T) {
	const method = http.MethodGet
	e := New()
	ri := e.GET("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	assert.Equal(t, method, ri.Method)
	assert.Equal(t, "/", ri.Path)
	assert.Equal(t, method+":/", ri.Name)
	assert.Nil(t, ri.Parameters)

	gotStatus, gotBody := request(method, "/", e)
	assert.Equal(t, http.StatusTeapot, gotStatus)
	assert.Equal(t, "OK", gotBody)
}
// TestEchoHead checks the RouteInfo returned by HEAD and that a HEAD request
// reaches the registered handler.
func TestEchoHead(t *testing.T) {
	const method = http.MethodHead
	e := New()
	ri := e.HEAD("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	assert.Equal(t, method, ri.Method)
	assert.Equal(t, "/", ri.Path)
	assert.Equal(t, method+":/", ri.Name)
	assert.Nil(t, ri.Parameters)

	gotStatus, gotBody := request(method, "/", e)
	assert.Equal(t, http.StatusTeapot, gotStatus)
	assert.Equal(t, "OK", gotBody)
}
// TestEchoOptions checks the RouteInfo returned by OPTIONS and that an OPTIONS
// request reaches the registered handler.
func TestEchoOptions(t *testing.T) {
	const method = http.MethodOptions
	e := New()
	ri := e.OPTIONS("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	assert.Equal(t, method, ri.Method)
	assert.Equal(t, "/", ri.Path)
	assert.Equal(t, method+":/", ri.Name)
	assert.Nil(t, ri.Parameters)

	gotStatus, gotBody := request(method, "/", e)
	assert.Equal(t, http.StatusTeapot, gotStatus)
	assert.Equal(t, "OK", gotBody)
}
// TestEchoPatch checks the RouteInfo returned by PATCH and that a PATCH request
// reaches the registered handler.
func TestEchoPatch(t *testing.T) {
	const method = http.MethodPatch
	e := New()
	ri := e.PATCH("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	assert.Equal(t, method, ri.Method)
	assert.Equal(t, "/", ri.Path)
	assert.Equal(t, method+":/", ri.Name)
	assert.Nil(t, ri.Parameters)

	gotStatus, gotBody := request(method, "/", e)
	assert.Equal(t, http.StatusTeapot, gotStatus)
	assert.Equal(t, "OK", gotBody)
}
// TestEchoPost checks the RouteInfo returned by POST and that a POST request
// reaches the registered handler.
func TestEchoPost(t *testing.T) {
	const method = http.MethodPost
	e := New()
	ri := e.POST("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	assert.Equal(t, method, ri.Method)
	assert.Equal(t, "/", ri.Path)
	assert.Equal(t, method+":/", ri.Name)
	assert.Nil(t, ri.Parameters)

	gotStatus, gotBody := request(method, "/", e)
	assert.Equal(t, http.StatusTeapot, gotStatus)
	assert.Equal(t, "OK", gotBody)
}
// TestEchoPut checks the RouteInfo returned by PUT and that a PUT request reaches
// the registered handler.
func TestEchoPut(t *testing.T) {
	const method = http.MethodPut
	e := New()
	ri := e.PUT("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	assert.Equal(t, method, ri.Method)
	assert.Equal(t, "/", ri.Path)
	assert.Equal(t, method+":/", ri.Name)
	assert.Nil(t, ri.Parameters)

	gotStatus, gotBody := request(method, "/", e)
	assert.Equal(t, http.StatusTeapot, gotStatus)
	assert.Equal(t, "OK", gotBody)
}
// TestEchoTrace checks the RouteInfo returned by TRACE and that a TRACE request
// reaches the registered handler.
func TestEchoTrace(t *testing.T) {
	const method = http.MethodTrace
	e := New()
	ri := e.TRACE("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	assert.Equal(t, method, ri.Method)
	assert.Equal(t, "/", ri.Path)
	assert.Equal(t, method+":/", ri.Name)
	assert.Nil(t, ri.Parameters)

	gotStatus, gotBody := request(method, "/", e)
	assert.Equal(t, http.StatusTeapot, gotStatus)
	assert.Equal(t, "OK", gotBody)
}
// TestEcho_Any checks the RouteInfo produced by Any and that a method without
// its own route (TRACE here) is served by the Any handler.
func TestEcho_Any(t *testing.T) {
	e := New()
	anyHandler := func(c *Context) error {
		return c.String(http.StatusTeapot, "OK from ANY")
	}
	ri := e.Any("/activate", anyHandler)

	assert.Equal(t, RouteAny, ri.Method)
	assert.Equal(t, "/activate", ri.Path)
	assert.Equal(t, RouteAny+":/activate", ri.Name)
	assert.Nil(t, ri.Parameters)

	code, text := request(http.MethodTrace, "/activate", e)
	assert.Equal(t, http.StatusTeapot, code)
	assert.Equal(t, `OK from ANY`, text)
}
// TestEcho_Any_hasLowerPriority checks that a method-specific route wins over an
// Any route on the same path, while other methods still fall back to Any.
func TestEcho_Any_hasLowerPriority(t *testing.T) {
	e := New()
	e.Any("/activate", func(c *Context) error {
		return c.String(http.StatusTeapot, "ANY")
	})
	e.GET("/activate", func(c *Context) error {
		return c.String(http.StatusLocked, "GET")
	})

	for _, step := range []struct {
		method     string
		wantStatus int
		wantBody   string
	}{
		{method: http.MethodTrace, wantStatus: http.StatusTeapot, wantBody: `ANY`},
		{method: http.MethodGet, wantStatus: http.StatusLocked, wantBody: `GET`},
	} {
		status, body := request(step.method, "/activate", e)
		assert.Equal(t, step.wantStatus, status)
		assert.Equal(t, step.wantBody, body)
	}
}
func TestEchoMatch(t *testing.T) { // JFC
e := New()
ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c *Context) error {
return c.String(http.StatusOK, "Match")
})
assert.Len(t, ris, 2)
}
// TestEchoServeHTTPPathEncoding checks that routing works on the raw (still
// percent-encoded) path: an encoded slash does not split a path segment.
func TestEchoServeHTTPPathEncoding(t *testing.T) {
	e := New()
	e.GET("/with/slash", func(c *Context) error {
		return c.String(http.StatusOK, "/with/slash")
	})
	e.GET("/:id", func(c *Context) error {
		return c.String(http.StatusOK, c.Param("id"))
	})

	type testCase struct {
		name         string
		whenURL      string
		expectURL    string
		expectStatus int
	}
	for _, tc := range []testCase{
		{
			name:         "url with encoding is not decoded for routing",
			whenURL:      "/with%2Fslash",
			expectURL:    "with%2Fslash", // `%2F` stays encoded and is matched by /:id
			expectStatus: http.StatusOK,
		},
		{
			name:         "url without encoding is used as is",
			whenURL:      "/with/slash",
			expectURL:    "/with/slash",
			expectStatus: http.StatusOK,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, tc.whenURL, nil))

			assert.Equal(t, tc.expectStatus, rec.Code)
			assert.Equal(t, tc.expectURL, rec.Body.String())
		})
	}
}
// TestEchoGroup checks which middlewares run for routes on the root, on a group,
// and on a group nested inside another group.
func TestEchoGroup(t *testing.T) {
	e := New()
	var trail bytes.Buffer

	// mark returns a middleware that appends label to trail before continuing.
	mark := func(label string) MiddlewareFunc {
		return func(next HandlerFunc) HandlerFunc {
			return func(c *Context) error {
				trail.WriteString(label)
				return next(c)
			}
		}
	}
	noContent := func(c *Context) error {
		return c.NoContent(http.StatusOK)
	}

	e.Use(mark("0"))
	e.GET("/users", noContent)

	g1 := e.Group("/group1")
	g1.Use(mark("1"))
	g1.GET("", noContent)

	// A nested group runs its parent group's middleware as well as its own.
	g2 := e.Group("/group2")
	g2.Use(mark("2"))
	g3 := g2.Group("/group3")
	g3.Use(mark("3"))
	g3.GET("", noContent)

	for _, step := range []struct{ path, want string }{
		{"/users", "0"},
		{"/group1", "01"},
		{"/group2/group3", "023"},
	} {
		trail.Reset()
		request(http.MethodGet, step.path, e)
		assert.Equal(t, step.want, trail.String())
	}
}
func TestEcho_RouteNotFound(t *testing.T) {
var testCases = []struct {
expectRoute any
name string
whenURL string
expectCode int
}{
{
name: "404, route to static not found handler /a/c/xx",
whenURL: "/a/c/xx",
expectRoute: "GET /a/c/xx",
expectCode: http.StatusNotFound,
},
{
name: "404, route to path param not found handler /a/:file",
whenURL: "/a/echo.exe",
expectRoute: "GET /a/:file",
expectCode: http.StatusNotFound,
},
{
name: "404, route to any not found handler /*",
whenURL: "/b/echo.exe",
expectRoute: "GET /*",
expectCode: http.StatusNotFound,
},
{
name: "200, route /a/c/df to /a/c/df",
whenURL: "/a/c/df",
expectRoute: "GET /a/c/df",
expectCode: http.StatusOK,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
okHandler := func(c *Context) error {
return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
}
notFoundHandler := func(c *Context) error {
return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
}
e.GET("/", okHandler)
e.GET("/a/c/df", okHandler)
e.GET("/a/b*", okHandler)
e.PUT("/*", okHandler)
e.RouteNotFound("/a/c/xx", notFoundHandler) // static
e.RouteNotFound("/a/:file", notFoundHandler) // param
e.RouteNotFound("/*", notFoundHandler) // any
req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, tc.expectCode, rec.Code)
assert.Equal(t, tc.expectRoute, rec.Body.String())
})
}
}
// TestEchoNotFound checks that an instance with no routes answers 404.
func TestEchoNotFound(t *testing.T) {
	rec := httptest.NewRecorder()
	New().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/files", nil))
	assert.Equal(t, http.StatusNotFound, rec.Code)
}
// TestEchoMethodNotAllowed checks that a request whose method has no route on an
// existing path gets 405 and an Allow header listing the permitted methods.
func TestEchoMethodNotAllowed(t *testing.T) {
	e := New()
	e.GET("/", func(c *Context) error {
		return c.String(http.StatusOK, "Echo!")
	})

	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/", nil))

	assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
	assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
}
// TestEcho_OnAddRoute checks that the OnAddRoute hook sees every added route, and
// that an error returned from the hook aborts AddRoute and is returned to the
// caller without the route being registered.
func TestEcho_OnAddRoute(t *testing.T) {
	routeUnderTest := Route{
		Method:      http.MethodGet,
		Path:        "/api/files/:id",
		Handler:     notFoundHandler,
		Middlewares: nil,
		Name:        "x",
	}

	cases := []struct {
		whenRoute   Route
		whenError   error
		name        string
		expectError string
		expectAdded []string
		expectLen   int
	}{
		{
			name:        "ok",
			whenRoute:   routeUnderTest,
			whenError:   nil,
			expectAdded: []string{"/static", "/api/files/:id"},
			expectError: "",
			expectLen:   2,
		},
		{
			name:        "nok, error is returned",
			whenRoute:   routeUnderTest,
			whenError:   errors.New("nope"),
			expectAdded: []string{"/static"},
			expectError: "nope",
			expectLen:   1,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			added := make([]string, 0)
			e.OnAddRoute = func(route Route) error {
				// Always accept the first route (GET /static) so that only the route
				// under test can be rejected in the "nok" case.
				if tc.whenError != nil && len(added) > 0 {
					return tc.whenError
				}
				added = append(added, route.Path)
				return nil
			}
			e.GET("/static", notFoundHandler)

			_, err := e.AddRoute(tc.whenRoute)
			if tc.expectError == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tc.expectError)
			}
			assert.Len(t, e.Router().Routes(), tc.expectLen)
			assert.Equal(t, tc.expectAdded, added)
		})
	}
}
// TestEchoContext checks that AcquireContext returns a *Context, which is then
// returned with ReleaseContext.
func TestEchoContext(t *testing.T) {
	e := New()
	ctx := e.AcquireContext()
	defer e.ReleaseContext(ctx)

	assert.IsType(t, new(Context), ctx)
}
func TestPreMiddlewares(t *testing.T) {
e := New()
assert.Equal(t, 0, len(e.PreMiddlewares()))
e.Pre(func(next HandlerFunc) HandlerFunc {
return func(c *Context) error {
return next(c)
}
})
assert.Equal(t, 1, len(e.PreMiddlewares()))
}
func TestMiddlewares(t *testing.T) {
e := New()
assert.Equal(t, 0, len(e.Middlewares()))
e.Use(func(next HandlerFunc) HandlerFunc {
return func(c *Context) error {
return next(c)
}
})
assert.Equal(t, 1, len(e.Middlewares()))
}
// TestEcho_Start checks that Start returns the listener error when the address is
// already in use, instead of blocking.
func TestEcho_Start(t *testing.T) {
	e := New()
	e.GET("/", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	// Keep a random port occupied so that Start is forced to fail binding it.
	occupied, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatal(err)
	}
	defer occupied.Close()

	startErr := make(chan error, 1)
	go func() { startErr <- e.Start(occupied.Addr().String()) }()

	wantSubstr := "bind: address already in use"
	if runtime.GOOS == "windows" {
		wantSubstr = "bind: Only one usage of each socket address"
	}

	select {
	case err := <-startErr:
		assert.Contains(t, err.Error(), wantSubstr)
	case <-time.After(250 * time.Millisecond):
		t.Fatal("start did not error out")
	}
}
// request sends a body-less request with the given method and path through e and
// returns the recorded status code and response body.
func request(method, path string, e *Echo) (int, string) {
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(method, path, nil))
	return rec.Code, rec.Body.String()
}
// customError is a test error that supplies its own HTTP status code and renders
// itself as JSON in the form {"x":"<Message>"}.
type customError struct {
	Code    int
	Message string
}

// StatusCode returns the HTTP status code carried by the error.
func (ce *customError) StatusCode() int { return ce.Code }

// MarshalJSON renders the error as {"x":"<Message>"}. Message is inserted as is,
// without JSON escaping, which is sufficient for the fixed strings used in tests.
func (ce *customError) MarshalJSON() ([]byte, error) {
	return fmt.Appendf(nil, `{"x":"%v"}`, ce.Message), nil
}

// Error returns Message, satisfying the error interface.
func (ce *customError) Error() string { return ce.Message }
// TestDefaultHTTPErrorHandler checks the status code and body produced by
// DefaultHTTPErrorHandler for HTTPError values (plain and wrapped), generic
// errors, HEAD requests and an error implementing MarshalJSON + HTTPStatusCoder,
// with and without exposing internal error messages. It also verifies what was
// written to the logger during the request.
func TestDefaultHTTPErrorHandler(t *testing.T) {
	var testCases = []struct {
		whenError        error
		name             string
		whenMethod       string
		expectBody       string
		expectLogged     string // expected logger output; empty means nothing is logged
		expectStatus     int
		givenExposeError bool
	}{
		{
			name:             "ok, expose error = true, HTTPError, no wrapped err",
			givenExposeError: true,
			whenError:        &HTTPError{Code: http.StatusTeapot, Message: "my_error"},
			expectStatus:     http.StatusTeapot,
			expectBody:       `{"message":"my_error"}` + "\n",
		},
		{
			name:             "ok, expose error = true, HTTPError + wrapped error",
			givenExposeError: true,
			whenError:        HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(errors.New("internal_error")),
			expectStatus:     http.StatusTeapot,
			expectBody:       `{"error":"internal_error","message":"my_error"}` + "\n",
		},
		{
			name:             "ok, expose error = true, HTTPError + wrapped HTTPError",
			givenExposeError: true,
			whenError:        HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}),
			expectStatus:     http.StatusTeapot,
			expectBody:       `{"error":"code=418, message=early_error","message":"my_error"}` + "\n",
		},
		{
			name:         "ok, expose error = false, HTTPError",
			whenError:    &HTTPError{Code: http.StatusTeapot, Message: "my_error"},
			expectStatus: http.StatusTeapot,
			expectBody:   `{"message":"my_error"}` + "\n",
		},
		{
			name:         "ok, expose error = false, HTTPError, no message",
			whenError:    &HTTPError{Code: http.StatusTeapot, Message: ""},
			expectStatus: http.StatusTeapot,
			expectBody:   `{"message":"I'm a teapot"}` + "\n",
		},
		{
			name:         "ok, expose error = false, HTTPError + internal HTTPError",
			whenError:    HTTPError{Code: http.StatusTooEarly, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}),
			expectStatus: http.StatusTooEarly,
			expectBody:   `{"message":"my_error"}` + "\n",
		},
		{
			name:             "ok, expose error = true, Error",
			givenExposeError: true,
			whenError:        fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
			expectStatus:     http.StatusInternalServerError,
			expectBody:       `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n",
		},
		{
			name:         "ok, expose error = false, Error",
			whenError:    fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
			expectStatus: http.StatusInternalServerError,
			expectBody:   `{"message":"Internal Server Error"}` + "\n",
		},
		{
			name:             "ok, http.HEAD, expose error = true, Error",
			givenExposeError: true,
			whenMethod:       http.MethodHead,
			whenError:        fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
			expectStatus:     http.StatusInternalServerError,
			expectBody:       ``,
		},
		{
			name:         "ok, custom error implement MarshalJSON + HTTPStatusCoder",
			whenMethod:   http.MethodGet,
			whenError:    &customError{Code: http.StatusTeapot, Message: "custom error msg"},
			expectStatus: http.StatusTeapot,
			expectBody:   `{"x":"custom error msg"}` + "\n",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			buf := new(bytes.Buffer)
			e := New()
			// Send logger output to buf so that the expectLogged assertion below is
			// meaningful; otherwise buf would never be written to.
			e.Logger = slog.New(slog.NewTextHandler(buf, nil))
			e.Any("/path", func(c *Context) error {
				return tc.whenError
			})
			e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError)

			method := http.MethodGet
			if tc.whenMethod != "" {
				method = tc.whenMethod
			}
			c, b := request(method, "/path", e)

			assert.Equal(t, tc.expectStatus, c)
			assert.Equal(t, tc.expectBody, b)
			assert.Equal(t, tc.expectLogged, buf.String())
		})
	}
}
// TestDefaultHTTPErrorHandler_CommitedResponse ensures the default error handler
// does not write anything once the response has already been committed.
func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) {
	e := New()
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)
	c.orgResponse.Committed = true // pretend the response has already been sent

	handle := DefaultHTTPErrorHandler(false)
	handle(c, errors.New("my_error"))

	// httptest.ResponseRecorder keeps its default 200 when nothing is written.
	assert.Equal(t, http.StatusOK, rec.Code)
}
// benchmarkEchoRoutes registers every route in routes with a no-op handler, then
// measures dispatching one request to each route's method and path per
// iteration. A single request and recorder are reused, so the measurement covers
// routing and dispatch rather than request construction.
func benchmarkEchoRoutes(b *testing.B, routes []testRoute) {
	e := New()
	noop := func(c *Context) error {
		return nil
	}
	for _, route := range routes {
		e.Add(route.Method, route.Path, noop)
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	w := httptest.NewRecorder()

	b.ReportAllocs()
	// b.Loop resets the timer on its first call, so the setup above is not measured.
	for b.Loop() {
		for _, route := range routes {
			req.Method = route.Method
			req.URL.Path = route.Path
			e.ServeHTTP(w, req)
		}
	}
}
// BenchmarkEchoStaticRoutes measures dispatch over the staticRoutes table.
func BenchmarkEchoStaticRoutes(b *testing.B) {
	benchmarkEchoRoutes(b, staticRoutes)
}

// BenchmarkEchoStaticRoutesMisses runs exactly the same workload as
// BenchmarkEchoStaticRoutes.
// NOTE(review): benchmarkEchoRoutes only requests routes it has registered, so
// despite its name this does not exercise misses. Confirm whether a dedicated
// miss table was intended.
func BenchmarkEchoStaticRoutesMisses(b *testing.B) {
	benchmarkEchoRoutes(b, staticRoutes)
}

// BenchmarkEchoGitHubAPI measures dispatch over the gitHubAPI route table.
func BenchmarkEchoGitHubAPI(b *testing.B) {
	benchmarkEchoRoutes(b, gitHubAPI)
}

// BenchmarkEchoGitHubAPIMisses runs exactly the same workload as
// BenchmarkEchoGitHubAPI.
// NOTE(review): like BenchmarkEchoStaticRoutesMisses, it only requests
// registered routes and therefore measures hits, not misses.
func BenchmarkEchoGitHubAPIMisses(b *testing.B) {
	benchmarkEchoRoutes(b, gitHubAPI)
}

// BenchmarkEchoParseAPI measures dispatch over the parseAPI route table.
func BenchmarkEchoParseAPI(b *testing.B) {
	benchmarkEchoRoutes(b, parseAPI)
}
================================================
FILE: echotest/context.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echotest
import (
"bytes"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/labstack/echo/v5"
)
// ContextConfig is configuration for creating echo.Context for testing purposes.
// ContextConfig is configuration for creating echo.Context for testing purposes.
//
// Convert it with ToContext, ToContextRecorder or ServeWithHandler. Fields left at
// their zero value are not applied.
type ContextConfig struct {
	// Request will be used instead of default `httptest.NewRequest(http.MethodGet, "/", nil)`.
	// A given Request is modified in place (query, headers and body as configured
	// below); its Method is only switched to POST for the default request.
	Request *http.Request

	// Response will be used instead of default `httptest.NewRecorder()`
	Response *httptest.ResponseRecorder

	// QueryValues, when non-empty, will be set as Request.URL.RawQuery value
	// (replacing any existing query).
	QueryValues url.Values

	// Headers, when non-empty, will be set as Request.Header value (replacing the
	// request's existing headers).
	Headers http.Header

	// PathValues initializes context.PathValues with given value. When RouteInfo is
	// nil, the names of these values also become the generated route's Parameters.
	PathValues echo.PathValues

	// RouteInfo initializes context.RouteInfo() with given value. When nil, a route
	// with an empty Name, the request's method, Path "/test" and parameter names
	// taken from PathValues is used.
	RouteInfo *echo.RouteInfo

	// FormValues creates form-urlencoded form out of given values. If there is no
	// `content-type` header it will be set to `application/x-www-form-urlencoded`
	// In case Request was not set the Request.Method is set to `POST`
	//
	// FormValues, MultipartForm and JSONBody are mutually exclusive. If several are
	// set, only the first one in that order (FormValues, MultipartForm, JSONBody)
	// is applied.
	FormValues url.Values

	// MultipartForm creates multipart form out of given value. If there is no
	// `content-type` header it will be set to `multipart/form-data` (including the
	// generated boundary).
	// In case Request was not set the Request.Method is set to `POST`
	//
	// FormValues, MultipartForm and JSONBody are mutually exclusive.
	MultipartForm *MultipartForm

	// JSONBody creates JSON body out of given bytes. If there is no
	// `content-type` header it will be set to `application/json`
	// In case Request was not set the Request.Method is set to `POST`
	// A non-nil but empty slice still counts as "set".
	//
	// FormValues, MultipartForm and JSONBody are mutually exclusive.
	JSONBody []byte
}
// MultipartForm is used to create multipart form out of given value
// MultipartForm is used to create multipart form out of given value
type MultipartForm struct {
	// Fields are written as plain form fields. Map iteration order is random, so
	// the order of the parts in the encoded body is not fixed.
	Fields map[string]string
	// Files are written as file parts, in slice order, after all Fields.
	Files []MultipartFormFile
}
// MultipartFormFile is used to create file in multipart form out of given value
// MultipartFormFile is used to create file in multipart form out of given value
type MultipartFormFile struct {
	// Fieldname is the form field name under which the file is sent.
	Fieldname string
	// Filename is the file name reported in the part's headers.
	Filename string
	// Content is the raw file content.
	Content []byte
}
// ToContext converts ContextConfig to echo.Context
// ToContext converts ContextConfig to echo.Context. It is ToContextRecorder with the
// response recorder discarded; use ToContextRecorder when the recorder is needed.
func (conf ContextConfig) ToContext(t *testing.T) *echo.Context {
	ctx, _ := conf.ToContextRecorder(t)
	return ctx
}
// ToContextRecorder converts ContextConfig to echo.Context and httptest.ResponseRecorder
// ToContextRecorder converts ContextConfig to echo.Context and httptest.ResponseRecorder.
//
// An unset Request or Response defaults to a GET "/" request or a new recorder.
// A non-empty QueryValues replaces Request.URL.RawQuery. A non-empty Headers
// replaces Request.Header; a copy is stored so the caller's map is never modified.
// At most one body source is applied, checked in the order FormValues,
// MultipartForm, JSONBody. The body source sets Content-Type when none is present
// and switches the default request to POST. Without RouteInfo, a route with Path
// "/test", whose parameter names come from PathValues, is used. Failures while
// building a multipart body abort the test via t.Fatal.
func (conf ContextConfig) ToContextRecorder(t *testing.T) (*echo.Context, *httptest.ResponseRecorder) {
	if conf.Response == nil {
		conf.Response = httptest.NewRecorder()
	}
	isDefaultRequest := conf.Request == nil
	if isDefaultRequest {
		conf.Request = httptest.NewRequest(http.MethodGet, "/", nil)
	}
	if len(conf.QueryValues) > 0 {
		conf.Request.URL.RawQuery = conf.QueryValues.Encode()
	}
	if len(conf.Headers) > 0 {
		// Clone so that setting Content-Type below does not leak into the caller's
		// map, which may be shared between several ContextConfig values.
		conf.Request.Header = conf.Headers.Clone()
	}
	if conf.Request.Header == nil {
		// A caller-supplied request may carry a nil Header; Header.Set would panic.
		conf.Request.Header = make(http.Header)
	}

	// applyBody installs body as the request body and fills in the shared defaults.
	applyBody := func(body io.Reader, size int, contentType string) {
		conf.Request.Body = io.NopCloser(body)
		conf.Request.ContentLength = int64(size)
		if conf.Request.Header.Get(echo.HeaderContentType) == "" {
			conf.Request.Header.Set(echo.HeaderContentType, contentType)
		}
		if isDefaultRequest {
			conf.Request.Method = http.MethodPost
		}
	}

	switch {
	case len(conf.FormValues) > 0:
		body := strings.NewReader(conf.FormValues.Encode())
		applyBody(body, body.Len(), echo.MIMEApplicationForm)
	case conf.MultipartForm != nil:
		var body bytes.Buffer
		mw := multipart.NewWriter(&body)
		for field, value := range conf.MultipartForm.Fields {
			if err := mw.WriteField(field, value); err != nil {
				t.Fatal(err)
			}
		}
		for _, file := range conf.MultipartForm.Files {
			fw, err := mw.CreateFormFile(file.Fieldname, file.Filename)
			if err != nil {
				t.Fatal(err)
			}
			if _, err = fw.Write(file.Content); err != nil {
				t.Fatal(err)
			}
		}
		// Close writes the terminating boundary; it must run before the length is taken.
		if err := mw.Close(); err != nil {
			t.Fatal(err)
		}
		applyBody(&body, body.Len(), mw.FormDataContentType())
	case conf.JSONBody != nil:
		body := bytes.NewReader(conf.JSONBody)
		applyBody(body, body.Len(), echo.MIMEApplicationJSON)
	}

	ec := echo.NewContext(conf.Request, conf.Response, echo.New())
	if conf.RouteInfo == nil {
		conf.RouteInfo = &echo.RouteInfo{
			Name:       "",
			Method:     conf.Request.Method,
			Path:       "/test",
			Parameters: []string{},
		}
		for _, p := range conf.PathValues {
			conf.RouteInfo.Parameters = append(conf.RouteInfo.Parameters, p.Name)
		}
	}
	ec.InitializeRoute(conf.RouteInfo, &conf.PathValues)
	return ec, conf.Response
}
// ServeWithHandler serves ContextConfig with given handler and returns httptest.ResponseRecorder for response checking
// ServeWithHandler serves ContextConfig with given handler and returns httptest.ResponseRecorder for response checking.
//
// If handler returns an error, the error is passed to an error handler. This is
// echo.DefaultHTTPErrorHandler(false) unless opts contains an echo.HTTPErrorHandler
// or a plain func(*echo.Context, error). When several are given, the last one
// wins. Options of any other type are ignored.
func (conf ContextConfig) ServeWithHandler(t *testing.T, handler echo.HandlerFunc, opts ...any) *httptest.ResponseRecorder {
	c, rec := conf.ToContextRecorder(t)

	errHandler := echo.DefaultHTTPErrorHandler(false)
	for _, opt := range opts {
		switch o := opt.(type) {
		case echo.HTTPErrorHandler:
			errHandler = o
		case func(*echo.Context, error):
			// A function literal has an unnamed type, so it does not match the named
			// echo.HTTPErrorHandler case above and would otherwise be silently ignored.
			errHandler = o
		}
	}

	if err := handler(c); err != nil {
		errHandler(c, err)
	}
	return rec
}
================================================
FILE: echotest/context_external_test.go
================================================
package echotest_test
import (
"net/http"
"testing"
"github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/echotest"
"github.com/stretchr/testify/assert"
)
// TestToContext_JSONBody (external package) checks that a JSONBody config can be
// bound, and that it turns the default request into a JSON POST.
func TestToContext_JSONBody(t *testing.T) {
	conf := echotest.ContextConfig{
		JSONBody: echotest.LoadBytes(t, "testdata/test.json"),
	}
	c := conf.ToContext(t)

	var payload struct {
		Field string `json:"field"`
	}
	if err := c.Bind(&payload); err != nil {
		t.Fatal(err)
	}

	req := c.Request()
	assert.Equal(t, "value", payload.Field)
	assert.Equal(t, http.MethodPost, req.Method)
	assert.Equal(t, echo.MIMEApplicationJSON, req.Header.Get(echo.HeaderContentType))
}
================================================
FILE: echotest/context_test.go
================================================
package echotest
import (
"net/http"
"net/url"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestServeWithHandler(t *testing.T) {
handler := func(c *echo.Context) error {
return c.String(http.StatusOK, c.QueryParam("key"))
}
testConf := ContextConfig{
QueryValues: url.Values{"key": []string{"value"}},
}
resp := testConf.ServeWithHandler(t, handler)
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, "value", resp.Body.String())
}
// TestServeWithHandler_error checks that an error handler passed as an option is
// used to render the handler's error.
func TestServeWithHandler_error(t *testing.T) {
	failing := func(c *echo.Context) error {
		return echo.NewHTTPError(http.StatusBadRequest, "something went wrong")
	}
	conf := ContextConfig{QueryValues: url.Values{"key": {"value"}}}

	rec := conf.ServeWithHandler(t, failing, echo.DefaultHTTPErrorHandler(true))

	assert.Equal(t, http.StatusBadRequest, rec.Code)
	assert.Equal(t, `{"message":"something went wrong"}`+"\n", rec.Body.String())
}
// TestToContext_QueryValues checks that QueryValues are readable as query params.
func TestToContext_QueryValues(t *testing.T) {
	c := ContextConfig{
		QueryValues: url.Values{"t": {"2006-01-02"}},
	}.ToContext(t)

	got, err := echo.QueryParam[string](c, "t")
	assert.NoError(t, err)
	assert.Equal(t, "2006-01-02", got)
}
// TestToContext_Headers checks that Headers end up on the request.
func TestToContext_Headers(t *testing.T) {
	c := ContextConfig{
		Headers: http.Header{echo.HeaderXRequestID: {"ABC"}},
	}.ToContext(t)

	assert.Equal(t, "ABC", c.Request().Header.Get(echo.HeaderXRequestID))
}
// TestToContext_PathValues checks that PathValues are readable through Param.
func TestToContext_PathValues(t *testing.T) {
	c := ContextConfig{
		PathValues: echo.PathValues{{Name: "key", Value: "value"}},
	}.ToContext(t)

	assert.Equal(t, "value", c.Param("key"))
}
func TestToContext_RouteInfo(t *testing.T) {
testConf := ContextConfig{
RouteInfo: &echo.RouteInfo{
Name: "my_route",
Method: http.MethodGet,
Path: "/:id",
Parameters: []string{"id"},
},
}
c := testConf.ToContext(t)
ri := c.RouteInfo()
assert.Equal(t, echo.RouteInfo{
Name: "my_route",
Method: http.MethodGet,
Path: "/:id",
Parameters: []string{"id"},
}, ri)
}
// TestToContext_FormValues checks that FormValues produce a urlencoded POST body.
func TestToContext_FormValues(t *testing.T) {
	c := ContextConfig{
		FormValues: url.Values{"key": {"value"}},
	}.ToContext(t)

	req := c.Request()
	assert.Equal(t, "value", c.FormValue("key"))
	assert.Equal(t, http.MethodPost, req.Method)
	assert.Equal(t, echo.MIMEApplicationForm, req.Header.Get(echo.HeaderContentType))
}
func TestToContext_MultipartForm(t *testing.T) {
testConf := ContextConfig{
MultipartForm: &MultipartForm{
Fields: map[string]string{
"key": "value",
},
Files: []MultipartFormFile{
{
Fieldname: "file",
Filename: "test.json",
Content: LoadBytes(t, "testdata/test.json"),
},
},
},
}
c := testConf.ToContext(t)
assert.Equal(t, "value", c.FormValue("key"))
assert.Equal(t, http.MethodPost, c.Request().Method)
assert.Equal(t, true, strings.HasPrefix(c.Request().Header.Get(echo.HeaderContentType), "multipart/form-data; boundary="))
fv, err := c.FormFile("file")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "test.json", fv.Filename)
assert.Equal(t, int64(23), fv.Size)
}
// TestToContext_JSONBody checks that a JSONBody config can be bound, and that it
// turns the default request into a JSON POST.
func TestToContext_JSONBody(t *testing.T) {
	c := ContextConfig{
		JSONBody: LoadBytes(t, "testdata/test.json"),
	}.ToContext(t)

	var payload struct {
		Field string `json:"field"`
	}
	if err := c.Bind(&payload); err != nil {
		t.Fatal(err)
	}

	req := c.Request()
	assert.Equal(t, "value", payload.Field)
	assert.Equal(t, http.MethodPost, req.Method)
	assert.Equal(t, echo.MIMEApplicationJSON, req.Header.Get(echo.HeaderContentType))
}
================================================
FILE: echotest/reader.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echotest
import (
"os"
"path/filepath"
"runtime"
"testing"
)
// loadBytesOpts post-processes the bytes loaded by LoadBytes; options are applied
// in the order they are passed (see TrimNewlineEnd).
type loadBytesOpts func([]byte) []byte
// TrimNewlineEnd instructs LoadBytes to remove `\n` from the end of loaded file.
// TrimNewlineEnd instructs LoadBytes to remove a single trailing `\n` from the end
// of the loaded file. Data without a trailing newline, including empty data, is
// returned unchanged. The result shares the input's backing array.
func TrimNewlineEnd(data []byte) []byte {
	// n > 0 (not n > 1) so that content consisting of just "\n" is trimmed too.
	if n := len(data); n > 0 && data[n-1] == '\n' {
		return data[:n-1]
	}
	return data
}
// LoadBytes is helper to load file contents relative to current (where test file is) package
// directory.
// LoadBytes is helper to load file contents relative to current (where test file is)
// package directory. name is resolved against the directory of the source file
// that calls LoadBytes, and each opt is then applied to the contents in order.
func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte {
	// Depth 2 skips loadBytes and LoadBytes, landing on LoadBytes' caller. This
	// only works while LoadBytes calls loadBytes directly.
	data := loadBytes(t, name, 2)
	for i := range opts {
		data = opts[i](data)
	}
	return data
}
// loadBytes reads the file name, resolved relative to the directory of the source
// file of the function callDepth frames up the stack (0 is loadBytes itself). Any
// failure, including being unable to determine the caller, aborts via t.Fatal.
func loadBytes(t *testing.T, name string, callDepth int) []byte {
	_, callerFile, _, ok := runtime.Caller(callDepth)
	if !ok {
		// Without caller information the path would silently resolve against ".".
		t.Fatal("echotest: unable to determine caller file for LoadBytes")
	}
	path := filepath.Join(filepath.Dir(callerFile), name)
	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatal(err)
	}
	return data
}
================================================
FILE: echotest/reader_external_test.go
================================================
package echotest_test
import (
"strings"
"testing"
"github.com/labstack/echo/v5/echotest"
"github.com/stretchr/testify/assert"
)
// testJSONContent is the content of testdata/test.json without its trailing newline.
const testJSONContent = `{
"field": "value"
}`
// TestLoadBytesOK checks that LoadBytes returns the fixture verbatim, trailing newline included.
func TestLoadBytesOK(t *testing.T) {
	want := []byte(testJSONContent + "\n")
	got := echotest.LoadBytes(t, "testdata/test.json")
	assert.Equal(t, want, got)
}
// TestLoadBytes_custom checks that a caller-supplied option transforms the loaded bytes.
func TestLoadBytes_custom(t *testing.T) {
	toUpper := func(b []byte) []byte {
		return []byte(strings.ToUpper(string(b)))
	}
	got := echotest.LoadBytes(t, "testdata/test.json", toUpper)
	want := []byte(strings.ToUpper(testJSONContent) + "\n")
	assert.Equal(t, want, got)
}
================================================
FILE: echotest/reader_test.go
================================================
package echotest
import (
"testing"
"github.com/stretchr/testify/assert"
)
// testJSONContent is the content of testdata/test.json without its trailing newline.
const testJSONContent = `{
"field": "value"
}`
// TestLoadBytesOK checks that LoadBytes without options keeps the file's trailing newline.
func TestLoadBytesOK(t *testing.T) {
	got := LoadBytes(t, "testdata/test.json")
	assert.Equal(t, []byte(testJSONContent+"\n"), got)
}
// TestLoadBytesOK_TrimNewlineEnd checks that the TrimNewlineEnd option strips the trailing newline.
func TestLoadBytesOK_TrimNewlineEnd(t *testing.T) {
	got := LoadBytes(t, "testdata/test.json", TrimNewlineEnd)
	assert.Equal(t, []byte(testJSONContent), got)
}
================================================
FILE: echotest/testdata/test.json
================================================
{
"field": "value"
}
================================================
FILE: go.mod
================================================
module github.com/labstack/echo/v5
go 1.25.0
require (
github.com/stretchr/testify v1.11.1
golang.org/x/net v0.49.0
golang.org/x/time v0.14.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/text v0.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
================================================
FILE: go.sum
================================================
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
================================================
FILE: group.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"io/fs"
"net/http"
)
// Group is a set of sub-routes for a specified route. It can be used for inner
// routes that share a common middleware or functionality that should be separate
// from the parent echo instance while still inheriting from it.
type Group struct {
	// echo is the parent instance; routes added through the group are registered with it.
	echo *Echo
	// prefix is prepended to the path of every route added through the group.
	prefix string
	// middleware is the group-level chain. AddRoute hands each route a copy of it as it is at
	// registration time, so middleware appended later via Use does not affect earlier routes.
	middleware []MiddlewareFunc
}
// Use implements `Echo#Use()` for sub-routes within the Group by appending middleware to the
// group's chain. Only routes added after this call receive the appended middleware.
// Group middlewares are not executed on request when there is no matching route found.
func (g *Group) Use(middleware ...MiddlewareFunc) {
	if len(middleware) == 0 {
		return
	}
	g.middleware = append(g.middleware, middleware...)
}
// CONNECT registers a CONNECT route below the group prefix, like `Echo#CONNECT()`.
// Panics on error.
func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodConnect, path, h, m...)
}

// DELETE registers a DELETE route below the group prefix, like `Echo#DELETE()`.
// Panics on error.
func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodDelete, path, h, m...)
}

// GET registers a GET route below the group prefix, like `Echo#GET()`.
// Panics on error.
func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodGet, path, h, m...)
}

// HEAD registers a HEAD route below the group prefix, like `Echo#HEAD()`.
// Panics on error.
func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodHead, path, h, m...)
}

// OPTIONS registers an OPTIONS route below the group prefix, like `Echo#OPTIONS()`.
// Panics on error.
func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodOptions, path, h, m...)
}

// PATCH registers a PATCH route below the group prefix, like `Echo#PATCH()`.
// Panics on error.
func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodPatch, path, h, m...)
}

// POST registers a POST route below the group prefix, like `Echo#POST()`.
// Panics on error.
func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodPost, path, h, m...)
}

// PUT registers a PUT route below the group prefix, like `Echo#PUT()`.
// Panics on error.
func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodPut, path, h, m...)
}

// TRACE registers a TRACE route below the group prefix, like `Echo#TRACE()`.
// Panics on error.
func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodTrace, path, h, m...)
}

// Any registers a route with the RouteAny method below the group prefix, like `Echo#Any()`.
// Panics on error.
func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
	return g.Add(RouteAny, path, handler, middleware...)
}
// Match implements `Echo#Match()` for sub-routes within the Group.
//
// A route is attempted for every method in methods; routes that register successfully stay
// registered even if others fail. When at least one registration failed, Match panics with
// the collected []error.
func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
	var failures []error
	added := make(Routes, 0, len(methods))
	for _, method := range methods {
		route := Route{
			Method:      method,
			Path:        path,
			Handler:     handler,
			Middlewares: middleware,
		}
		ri, err := g.AddRoute(route)
		if err == nil {
			added = append(added, ri)
			continue
		}
		failures = append(failures, err)
	}
	if len(failures) != 0 {
		panic(failures) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
	}
	return added
}
// Group creates a new sub-group whose prefix is this group's prefix followed by prefix.
// The sub-group starts with this group's middleware, followed by the optional middleware.
//
// Important! Group middlewares are only executed when a route matches exactly, not for
// 404 (not found) or 405 (method not allowed) cases. If that behaviour is needed, add a
// catch-all route `/*` to the group whose handler always returns 404.
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
	combined := make([]MiddlewareFunc, len(g.middleware), len(g.middleware)+len(middleware))
	copy(combined, g.middleware)
	combined = append(combined, middleware...)
	sg = g.echo.Group(g.prefix+prefix, combined...)
	return sg
}
// Static implements `Echo#Static()` for sub-routes within the Group. fsRoot is resolved inside
// the Echo instance's Filesystem via MustSubFS, which panics when the root cannot be used.
func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo {
	return g.StaticFS(pathPrefix, MustSubFS(g.echo.Filesystem, fsRoot), middleware...)
}
// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. It registers a GET
// route for pathPrefix+"*" that serves files from filesystem.
//
// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory")` to create a sub fs
// that applies the necessary prefix to directory paths. This is necessary because
// `//go:embed assets/images` embeds files with paths that include `assets/images` as their prefix.
func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo {
	handler := StaticDirectoryHandler(filesystem, false)
	return g.GET(pathPrefix+"*", handler, middleware...)
}
// FileFS implements `Echo#FileFS()` for sub-routes within the Group: it registers a GET route
// at path that serves file from filesystem.
func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo {
	return g.Add(http.MethodGet, path, StaticFileHandler(file, filesystem), m...)
}
// File implements `Echo#File()` for sub-routes within the Group: it registers a GET route at
// path whose handler responds via Context.File(file). Panics on error.
func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
	return g.GET(path, func(c *Context) error {
		return c.File(file)
	}, middleware...)
}
// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group by registering
// h under the special RouteNotFound method for the prefixed path. Panics on error.
//
// Example: `g.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })`
func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
	return g.Add(RouteNotFound, path, h, m...)
}
// Add implements `Echo#Add()` for sub-routes within the Group. It is the panicking counterpart
// of AddRoute: any registration error is raised via panic.
func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
	route := Route{Method: method, Path: path, Handler: handler, Middlewares: middleware}
	ri, err := g.AddRoute(route)
	if err != nil {
		panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
	}
	return ri
}
// AddRoute registers route with the parent Echo router after applying the group's prefix and
// group-level middleware to it. Registration errors are returned, not panicked.
func (g *Group) AddRoute(route Route) (RouteInfo, error) {
	// Each route gets its own copy of the group middleware so that multiple routes never share
	// one backing array; otherwise a later add() could overwrite middleware of an earlier route.
	groupMiddleware := make([]MiddlewareFunc, len(g.middleware))
	copy(groupMiddleware, g.middleware)
	return g.echo.add(route.WithPrefix(g.prefix, groupMiddleware))
}
================================================
FILE: group_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"io/fs"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
	e := New()
	middlewareCalled := false
	teapot := func(next HandlerFunc) HandlerFunc {
		return func(c *Context) error {
			middlewareCalled = true
			return c.NoContent(http.StatusTeapot)
		}
	}
	// the group has middleware but no routes, so requests below its prefix never reach it
	e.Group("/group", teapot)

	code, body := request(http.MethodGet, "/group/nope", e)
	assert.Equal(t, http.StatusNotFound, code)
	assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
	assert.False(t, middlewareCalled)
}
func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
	e := New()
	middlewareCalled := false
	teapot := func(next HandlerFunc) HandlerFunc {
		return func(c *Context) error {
			middlewareCalled = true
			return c.NoContent(http.StatusTeapot)
		}
	}
	// the group has middleware and a route, but a request that matches no route must not
	// execute the group's middleware
	e.Group("/group", teapot).GET("/yes", handlerFunc)

	code, body := request(http.MethodGet, "/group/nope", e)
	assert.Equal(t, http.StatusNotFound, code)
	assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
	assert.False(t, middlewareCalled)
}
func TestGroup_multiLevelGroup(t *testing.T) {
	e := New()
	users := e.Group("/api").Group("/users")
	users.GET("/activate", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	code, body := request(http.MethodGet, "/api/users/activate", e)
	assert.Equal(t, http.StatusTeapot, code)
	assert.Equal(t, `OK`, body)
}
func TestGroupFile(t *testing.T) {
e := New()
g := e.Group("/group")
g.File("/walle", "_fixture/images/walle.png")
expectedData, err := os.ReadFile("_fixture/images/walle.png")
assert.Nil(t, err)
req := httptest.NewRequest(http.MethodGet, "/group/walle", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, expectedData, rec.Body.Bytes())
}
func TestGroupRouteMiddleware(t *testing.T) {
	// Ensure middleware slices are not re-used between routes of the same group
	e := New()
	g := e.Group("/group")
	handler := func(*Context) error { return nil }
	passThrough := func(next HandlerFunc) HandlerFunc {
		return func(c *Context) error {
			return next(c)
		}
	}
	respondWith := func(code int) MiddlewareFunc {
		return func(next HandlerFunc) HandlerFunc {
			return func(c *Context) error {
				return c.NoContent(code)
			}
		}
	}

	g.Use(passThrough, passThrough, passThrough)
	g.GET("/404", handler, respondWith(404))
	g.GET("/405", handler, respondWith(405))

	code, _ := request(http.MethodGet, "/group/404", e)
	assert.Equal(t, 404, code)
	code, _ = request(http.MethodGet, "/group/405", e)
	assert.Equal(t, 405, code)
}
func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
	// Ensure middleware and match any routes do not conflict
	e := New()
	g := e.Group("/group")
	passThrough := func(next HandlerFunc) HandlerFunc {
		return func(c *Context) error {
			return next(c)
		}
	}
	writeRoutePath := func(next HandlerFunc) HandlerFunc {
		return func(c *Context) error {
			return c.String(http.StatusOK, c.RouteInfo().Path)
		}
	}
	h := func(c *Context) error {
		return c.String(http.StatusOK, c.RouteInfo().Path)
	}

	g.Use(passThrough)
	g.GET("/help", h, writeRoutePath)
	g.GET("/*", h, writeRoutePath)
	g.GET("", h, writeRoutePath)
	e.GET("unrelated", h, writeRoutePath)
	e.GET("*", h, writeRoutePath)

	expectations := []struct {
		url       string
		routePath string
	}{
		{url: "/group/help", routePath: "/group/help"},
		{url: "/group/help/other", routePath: "/group/*"},
		{url: "/group/404", routePath: "/group/*"},
		{url: "/group", routePath: "/group"},
		{url: "/other", routePath: "/*"},
		{url: "/", routePath: "/*"},
	}
	for _, exp := range expectations {
		_, body := request(http.MethodGet, exp.url, e)
		assert.Equal(t, exp.routePath, body)
	}
}
func TestGroup_CONNECT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.CONNECT("/activate", func(c *Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodConnect, ri.Method)
assert.Equal(t, "/users/activate", ri.Path)
assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name)
assert.Nil(t, ri.Parameters)
status, body := request(http.MethodConnect, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_DELETE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.DELETE("/activate", func(c *Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodDelete, ri.Method)
assert.Equal(t, "/users/activate", ri.Path)
assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name)
assert.Nil(t, ri.Parameters)
status, body := request(http.MethodDelete, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_HEAD(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.HEAD("/activate", func(c *Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodHead, ri.Method)
assert.Equal(t, "/users/activate", ri.Path)
assert.Equal(t, http.MethodHead+":/users/activate", ri.Name)
assert.Nil(t, ri.Parameters)
status, body := request(http.MethodHead, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_OPTIONS(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.OPTIONS("/activate", func(c *Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodOptions, ri.Method)
assert.Equal(t, "/users/activate", ri.Path)
assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name)
assert.Nil(t, ri.Parameters)
status, body := request(http.MethodOptions, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PATCH(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PATCH("/activate", func(c *Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPatch, ri.Method)
assert.Equal(t, "/users/activate", ri.Path)
assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name)
assert.Nil(t, ri.Parameters)
status, body := request(http.MethodPatch, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_POST(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.POST("/activate", func(c *Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPost, ri.Method)
assert.Equal(t, "/users/activate", ri.Path)
assert.Equal(t, http.MethodPost+":/users/activate", ri.Name)
assert.Nil(t, ri.Parameters)
status, body := request(http.MethodPost, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_PUT(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.PUT("/activate", func(c *Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodPut, ri.Method)
assert.Equal(t, "/users/activate", ri.Path)
assert.Equal(t, http.MethodPut+":/users/activate", ri.Name)
assert.Nil(t, ri.Parameters)
status, body := request(http.MethodPut, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_TRACE(t *testing.T) {
e := New()
users := e.Group("/users")
ri := users.TRACE("/activate", func(c *Context) error {
return c.String(http.StatusTeapot, "OK")
})
assert.Equal(t, http.MethodTrace, ri.Method)
assert.Equal(t, "/users/activate", ri.Path)
assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name)
assert.Nil(t, ri.Parameters)
status, body := request(http.MethodTrace, "/users/activate", e)
assert.Equal(t, http.StatusTeapot, status)
assert.Equal(t, `OK`, body)
}
func TestGroup_RouteNotFound(t *testing.T) {
	var testCases = []struct {
		expectRoute any
		name        string
		whenURL     string
		expectCode  int
	}{
		{
			name:        "404, route to static not found handler /group/a/c/xx",
			whenURL:     "/group/a/c/xx",
			expectRoute: "GET /group/a/c/xx",
			expectCode:  http.StatusNotFound,
		},
		{
			name:        "404, route to path param not found handler /group/a/:file",
			whenURL:     "/group/a/echo.exe",
			expectRoute: "GET /group/a/:file",
			expectCode:  http.StatusNotFound,
		},
		{
			name:        "404, route to any not found handler /group/*",
			whenURL:     "/group/b/echo.exe",
			expectRoute: "GET /group/*",
			expectCode:  http.StatusNotFound,
		},
		{
			name:        "200, route /group/a/c/df to /group/a/c/df",
			whenURL:     "/group/a/c/df",
			expectRoute: "GET /group/a/c/df",
			expectCode:  http.StatusOK,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			describe := func(c *Context) string {
				return c.Request().Method + " " + c.Path()
			}
			okHandler := func(c *Context) error {
				return c.String(http.StatusOK, describe(c))
			}
			notFoundHandler := func(c *Context) error {
				return c.String(http.StatusNotFound, describe(c))
			}

			e := New()
			g := e.Group("/group")
			g.GET("/", okHandler)
			g.GET("/a/c/df", okHandler)
			g.GET("/a/b*", okHandler)
			g.PUT("/*", okHandler)
			g.RouteNotFound("/a/c/xx", notFoundHandler) // static
			g.RouteNotFound("/a/:file", notFoundHandler) // param
			g.RouteNotFound("/*", notFoundHandler)       // any

			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, tc.whenURL, nil))

			assert.Equal(t, tc.expectCode, rec.Code)
			assert.Equal(t, tc.expectRoute, rec.Body.String())
		})
	}
}
func TestGroup_Any(t *testing.T) {
	e := New()
	ri := e.Group("/users").Any("/activate", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK from ANY")
	})

	assert.Equal(t, RouteAny, ri.Method)
	assert.Equal(t, "/users/activate", ri.Path)
	assert.Equal(t, RouteAny+":/users/activate", ri.Name)
	assert.Nil(t, ri.Parameters)

	// any method should be dispatched to the RouteAny handler; TRACE is used as a sample
	code, body := request(http.MethodTrace, "/users/activate", e)
	assert.Equal(t, http.StatusTeapot, code)
	assert.Equal(t, `OK from ANY`, body)
}
func TestGroup_Match(t *testing.T) {
	e := New()
	methods := []string{http.MethodGet, http.MethodPost}
	ris := e.Group("/users").Match(methods, "/activate", func(c *Context) error {
		return c.String(http.StatusTeapot, "OK")
	})
	assert.Len(t, ris, 2)

	for _, method := range methods {
		code, body := request(method, "/users/activate", e)
		assert.Equal(t, http.StatusTeapot, code)
		assert.Equal(t, `OK`, body)
	}
}
func TestGroup_MatchWithErrors(t *testing.T) {
e := New()
users := e.Group("/users")
users.GET("/activate", func(c *Context) error {
return c.String(http.StatusOK, "OK")
})
myMethods := []string{http.MethodGet, http.MethodPost}
errs := func() (errs []error) {
defer func() {
if r := recover(); r != nil {
if tmpErr, ok := r.([]error); ok {
errs = tmpErr
return
}
panic(r)
}
}()
users.Match(myMethods, "/activate", func(c *Context) error {
return c.String(http.StatusTeapot, "OK")
})
return nil
}()
assert.Len(t, errs, 1)
assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
for _, m := range myMethods {
status, body := request(m, "/users/activate", e)
expect := http.StatusTeapot
if m == http.MethodGet {
expect = http.StatusOK
}
assert.Equal(t, expect, status)
assert.Equal(t, `OK`, body)
}
}
// TestGroup_Static checks that Group.Static registers a GET wildcard route below the group
// prefix and serves files from the given root.
func TestGroup_Static(t *testing.T) {
	e := New()
	g := e.Group("/books")
	ri := g.Static("/download", "_fixture")
	assert.Equal(t, http.MethodGet, ri.Method)
	assert.Equal(t, "/books/download*", ri.Path)
	assert.Equal(t, "GET:/books/download*", ri.Name)
	assert.Equal(t, []string{"*"}, ri.Parameters)

	// compare with the fixture on disk; a HasPrefix check against "" would always pass
	expected, err := os.ReadFile("_fixture/index.html")
	assert.NoError(t, err)

	req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, string(expected), rec.Body.String())
}
// TestGroup_StaticMultiTest exercises Group.Static with various prefixes and roots: file serving,
// directory index files, redirects and path traversal attempts.
//
// Body expectations: expectBodyFromFile compares the body with a fixture file on disk,
// expectBodyStartsWith checks a prefix, and when both are empty the body must be empty.
func TestGroup_StaticMultiTest(t *testing.T) {
	var testCases = []struct {
		name                  string
		givenPrefix           string
		givenRoot             string
		whenURL               string
		expectHeaderLocation  string
		expectBodyStartsWith  string
		expectBodyFromFile    string
		expectBodyNotContains string
		expectStatus          int
	}{
		{
			name:                 "ok",
			givenPrefix:          "/images",
			givenRoot:            "_fixture/images",
			whenURL:              "/test/images/walle.png",
			expectStatus:         http.StatusOK,
			expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
		},
		{
			name:                 "ok, without prefix",
			givenPrefix:          "",
			givenRoot:            "_fixture/images",
			whenURL:              "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
			expectStatus:         http.StatusOK,
			expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
		},
		{
			name:                 "nok, without prefix does not serve dir index",
			givenPrefix:          "",
			givenRoot:            "_fixture/images",
			whenURL:              "/test/", // `/test` + `*` creates route `/test*`
			expectStatus:         http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name:                 "No file",
			givenPrefix:          "/images",
			givenRoot:            "_fixture/scripts",
			whenURL:              "/test/images/bolt.png",
			expectStatus:         http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name:                 "Directory",
			givenPrefix:          "/images",
			givenRoot:            "_fixture/images",
			whenURL:              "/test/images/",
			expectStatus:         http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name:                 "Directory Redirect",
			givenPrefix:          "/",
			givenRoot:            "_fixture",
			whenURL:              "/test/folder",
			expectStatus:         http.StatusMovedPermanently,
			expectHeaderLocation: "/test/folder/",
			expectBodyStartsWith: "",
		},
		{
			name:                 "Directory Redirect with non-root path",
			givenPrefix:          "/static",
			givenRoot:            "_fixture",
			whenURL:              "/test/static",
			expectStatus:         http.StatusMovedPermanently,
			expectHeaderLocation: "/test/static/",
			expectBodyStartsWith: "",
		},
		{
			name:                 "Prefixed directory 404 (request URL without slash)",
			givenPrefix:          "/folder/", // trailing slash will intentionally not match "/folder"
			givenRoot:            "_fixture",
			whenURL:              "/test/folder", // no trailing slash
			expectStatus:         http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name:                 "Prefixed directory redirect (without slash redirect to slash)",
			givenPrefix:          "/folder", // no trailing slash shall match /folder and /folder/*
			givenRoot:            "_fixture",
			whenURL:              "/test/folder", // no trailing slash
			expectStatus:         http.StatusMovedPermanently,
			expectHeaderLocation: "/test/folder/",
			expectBodyStartsWith: "",
		},
		{
			name:               "Directory with index.html",
			givenPrefix:        "/",
			givenRoot:          "_fixture",
			whenURL:            "/test/",
			expectStatus:       http.StatusOK,
			expectBodyFromFile: "_fixture/index.html",
		},
		{
			name:               "Prefixed directory with index.html (prefix ending with slash)",
			givenPrefix:        "/assets/",
			givenRoot:          "_fixture",
			whenURL:            "/test/assets/",
			expectStatus:       http.StatusOK,
			expectBodyFromFile: "_fixture/index.html",
		},
		{
			name:               "Prefixed directory with index.html (prefix ending without slash)",
			givenPrefix:        "/assets",
			givenRoot:          "_fixture",
			whenURL:            "/test/assets/",
			expectStatus:       http.StatusOK,
			expectBodyFromFile: "_fixture/index.html",
		},
		{
			name:               "Sub-directory with index.html",
			givenPrefix:        "/",
			givenRoot:          "_fixture",
			whenURL:            "/test/folder/",
			expectStatus:       http.StatusOK,
			expectBodyFromFile: "_fixture/folder/index.html",
		},
		{
			name:                  "nok, URL encoded path traversal (single encoding, slash - unix separator)",
			givenRoot:             "_fixture/dist/public",
			whenURL:               "/%2e%2e%2fprivate.txt",
			expectStatus:          http.StatusNotFound,
			expectBodyStartsWith:  "{\"message\":\"Not Found\"}\n",
			expectBodyNotContains: `private file`,
		},
		{
			name:                  "nok, URL encoded path traversal (single encoding, backslash - windows separator)",
			givenRoot:             "_fixture/dist/public",
			whenURL:               "/%2e%2e%5cprivate.txt",
			expectStatus:          http.StatusNotFound,
			expectBodyStartsWith:  "{\"message\":\"Not Found\"}\n",
			expectBodyNotContains: `private file`,
		},
		{
			name:                 "do not allow directory traversal (backslash - windows separator)",
			givenPrefix:          "/",
			givenRoot:            "_fixture/",
			whenURL:              `/test/..\\middleware/basic_auth.go`,
			expectStatus:         http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
		{
			name:                 "do not allow directory traversal (slash - unix separator)",
			givenPrefix:          "/",
			givenRoot:            "_fixture/",
			whenURL:              `/test/../middleware/basic_auth.go`,
			expectStatus:         http.StatusNotFound,
			expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			g := e.Group("/test")
			g.Static(tc.givenPrefix, tc.givenRoot)

			req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)

			assert.Equal(t, tc.expectStatus, rec.Code)
			body := rec.Body.String()
			switch {
			case tc.expectBodyFromFile != "":
				// compare with the fixture itself so the served index content is actually verified
				expected, err := os.ReadFile(tc.expectBodyFromFile)
				assert.NoError(t, err)
				assert.Equal(t, string(expected), body)
			case tc.expectBodyStartsWith != "":
				assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
			default:
				assert.Equal(t, "", body)
			}
			if tc.expectBodyNotContains != "" {
				assert.NotContains(t, body, tc.expectBodyNotContains)
			}

			header := rec.Result().Header
			if tc.expectHeaderLocation != "" {
				// Header.Get avoids an index-out-of-range panic when Location is missing
				assert.Equal(t, tc.expectHeaderLocation, header.Get("Location"))
			} else {
				_, ok := header["Location"]
				assert.False(t, ok)
			}
		})
	}
}
func TestGroup_FileFS(t *testing.T) {
	var testCases = []struct {
		whenFS           fs.FS
		name             string
		whenPath         string
		whenFile         string
		givenURL         string
		expectStartsWith []byte
		expectCode       int
	}{
		{
			name:             "ok",
			whenPath:         "/walle",
			whenFS:           os.DirFS("_fixture/images"),
			whenFile:         "walle.png",
			givenURL:         "/assets/walle",
			expectCode:       http.StatusOK,
			expectStartsWith: []byte{0x89, 0x50, 0x4e},
		},
		{
			name:             "nok, requesting invalid path",
			whenPath:         "/walle",
			whenFS:           os.DirFS("_fixture/images"),
			whenFile:         "walle.png",
			givenURL:         "/assets/walle.png",
			expectCode:       http.StatusNotFound,
			expectStartsWith: []byte(`{"message":"Not Found"}`),
		},
		{
			name:             "nok, serving not existent file from filesystem",
			whenPath:         "/walle",
			whenFS:           os.DirFS("_fixture/images"),
			whenFile:         "not-existent.png",
			givenURL:         "/assets/walle",
			expectCode:       http.StatusNotFound,
			expectStartsWith: []byte(`{"message":"Not Found"}`),
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := New()
			e.Group("/assets").FileFS(tc.whenPath, tc.whenFile, tc.whenFS)

			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, tc.givenURL, nil))

			assert.Equal(t, tc.expectCode, rec.Code)
			// compare only the leading bytes of the response body
			body := rec.Body.Bytes()
			n := min(len(body), len(tc.expectStartsWith))
			assert.Equal(t, tc.expectStartsWith, body[:n])
		})
	}
}
func TestGroup_StaticPanic(t *testing.T) {
var testCases = []struct {
name string
givenRoot string
}{
{
name: "panics for ../",
givenRoot: "../images",
},
{
name: "panics for /",
givenRoot: "/images",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := New()
e.Filesystem = os.DirFS("./")
g := e.Group("/assets")
assert.Panics(t, func() {
g.Static("/images", tc.givenRoot)
})
})
}
}
// TestGroup_RouteNotFoundWithMiddleware checks when group middleware runs for not-found requests:
// a group-level RouteNotFound registered after Use gets the group middleware, while requests
// handled by the root-level RouteNotFound handler do not.
func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
	var testCases = []struct {
		expectBody             any
		name                   string
		whenURL                string
		expectCode             int
		givenCustom404         bool
		expectMiddlewareCalled bool
	}{
		{
			name:                   "ok, custom 404 handler is called with middleware",
			givenCustom404:         true,
			whenURL:                "/group/test3",
			expectBody:             "404 GET /group/*",
			expectCode:             http.StatusNotFound,
			expectMiddlewareCalled: true, // group RouteNotFound is added after g.Use, so it gets the group middleware
		},
		{
			name:                   "ok, default group 404 handler is not called with middleware",
			givenCustom404:         false,
			whenURL:                "/group/test3",
			expectBody:             "404 GET /*",
			expectCode:             http.StatusNotFound,
			expectMiddlewareCalled: false, // the root-level e.RouteNotFound handler runs, without group middleware
		},
		{
			name:                   "ok, (no slash) default group 404 handler is not called with middleware",
			givenCustom404:         false,
			whenURL:                "/group",
			expectBody:             "404 GET /*",
			expectCode:             http.StatusNotFound,
			expectMiddlewareCalled: false, // the root-level e.RouteNotFound handler runs, without group middleware
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			okHandler := func(c *Context) error {
				return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
			}
			notFoundHandler := func(c *Context) error {
				return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path())
			}
			e := New()
			e.GET("/test1", okHandler)
			e.RouteNotFound("/*", notFoundHandler)
			g := e.Group("/group")
			g.GET("/test1", okHandler)
			middlewareCalled := false
			g.Use(func(next HandlerFunc) HandlerFunc {
				return func(c *Context) error {
					middlewareCalled = true
					return next(c)
				}
			})
			if tc.givenCustom404 {
				g.RouteNotFound("/*", notFoundHandler)
			}
			req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)
			assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled)
			assert.Equal(t, tc.expectCode, rec.Code)
			assert.Equal(t, tc.expectBody, rec.Body.String())
		})
	}
}
================================================
FILE: httperror.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"errors"
"fmt"
"net/http"
)
// Following errors can produce HTTP status code by implementing HTTPStatusCoder interface.
// Each one is an *httpError constructed with the HTTP status noted next to it, so
// StatusCode(err) reports that status for them (and for errors wrapping them).
var (
	ErrBadRequest                  = &httpError{http.StatusBadRequest}            // 400
	ErrUnauthorized                = &httpError{http.StatusUnauthorized}          // 401
	ErrForbidden                   = &httpError{http.StatusForbidden}             // 403
	ErrNotFound                    = &httpError{http.StatusNotFound}              // 404
	ErrMethodNotAllowed            = &httpError{http.StatusMethodNotAllowed}      // 405
	ErrRequestTimeout              = &httpError{http.StatusRequestTimeout}        // 408
	ErrStatusRequestEntityTooLarge = &httpError{http.StatusRequestEntityTooLarge} // 413
	ErrUnsupportedMediaType        = &httpError{http.StatusUnsupportedMediaType}  // 415
	ErrTooManyRequests             = &httpError{http.StatusTooManyRequests}       // 429
	ErrInternalServerError         = &httpError{http.StatusInternalServerError}   // 500
	ErrBadGateway                  = &httpError{http.StatusBadGateway}            // 502
	ErrServiceUnavailable          = &httpError{http.StatusServiceUnavailable}    // 503
)
// Following errors fall into 500 (InternalServerError) category: they are plain errors.New
// values without a StatusCode method, so StatusCode reports 0 for them and
// ResolveResponseStatus maps them to 500 unless they are wrapped in an error implementing
// HTTPStatusCoder.
var (
	ErrValidatorNotRegistered = errors.New("validator not registered")
	ErrRendererNotRegistered  = errors.New("renderer not registered")
	ErrInvalidRedirectCode    = errors.New("invalid redirect status code")
	ErrCookieNotFound         = errors.New("cookie not found")
	ErrInvalidCertOrKeyType   = errors.New("invalid cert or key type, must be string or []byte")
	ErrInvalidListenerNetwork = errors.New("invalid listener network")
)
// HTTPStatusCoder is implemented by errors that know which HTTP status code should be sent
// in the response produced for them.
type HTTPStatusCoder interface {
	StatusCode() int
}

// StatusCode extracts an HTTP status code from err. The error chain is searched with errors.As,
// so errors wrapped with fmt.Errorf("%w") or errors.Join are inspected too. It returns 0 when
// err is nil or when no error in the chain implements HTTPStatusCoder.
func StatusCode(err error) int {
	var coder HTTPStatusCoder
	if !errors.As(err, &coder) {
		return 0
	}
	return coder.StatusCode()
}
// ResolveResponseStatus reports the HTTP status code that has been, or will be,
// sent for rw, taking an optional handler error into account. It also returns
// the *Response obtained from UnwrapResponse(rw); that value may be nil, and any
// error from unwrapping is ignored.
//
// It is meant for middleware and handlers that need to know the final status
// code, e.g. for logging or metrics.
//
// Resolution order:
//  1. A committed response keeps the status it was sent with and err is
//     ignored; a committed response with status 0 is reported as 200 OK.
//  2. A non-nil err decides the status: StatusCode(err) when that is non-zero,
//     500 Internal Server Error otherwise.
//  3. A non-zero status suggested on the response is used.
//  4. Otherwise 200 OK, the status net/http sends when WriteHeader is never called.
func ResolveResponseStatus(rw http.ResponseWriter, err error) (resp *Response, status int) {
	resp, _ = UnwrapResponse(rw)

	if resp != nil && resp.Committed {
		// The status line is already on the wire; nothing can change it now.
		if resp.Status != 0 {
			return resp, resp.Status
		}
		return resp, http.StatusOK
	}

	switch {
	case err != nil:
		if code := StatusCode(err); code != 0 {
			return resp, code
		}
		return resp, http.StatusInternalServerError
	case resp != nil && resp.Status != 0:
		return resp, resp.Status
	default:
		return resp, http.StatusOK
	}
}
// NewHTTPError creates a new *HTTPError with the given status code and message.
func NewHTTPError(code int, message string) *HTTPError {
	return &HTTPError{
		Code:    code,
		Message: message,
	}
}

// HTTPError represents an error that occurred while handling a request.
// It carries the HTTP status code for the response, a message and an optional
// wrapped cause (set through Wrap and exposed through Unwrap).
type HTTPError struct {
	// Code is status code for HTTP response. It is excluded from JSON output.
	Code int `json:"-"`
	// Message is the message for the error; serialized to JSON as "message".
	Message string `json:"message"`
	// err is the optional wrapped cause, set by Wrap.
	err error
}

// StatusCode returns status code for HTTP response. It implements HTTPStatusCoder.
func (he *HTTPError) StatusCode() int {
	return he.Code
}

// Error makes it compatible with the `error` interface. The result has the form
// "code=<Code>, message=<Message>" followed by ", err=<cause>" when a cause is
// wrapped. When Message is empty, the standard status text for Code is used.
func (he *HTTPError) Error() string {
	msg := he.Message
	if msg == "" {
		msg = http.StatusText(he.Code)
	}
	if he.err == nil {
		return fmt.Sprintf("code=%d, message=%v", he.Code, msg)
	}
	return fmt.Sprintf("code=%d, message=%v, err=%v", he.Code, msg, he.err.Error())
}

// Wrap returns a new *HTTPError with the same Code and Message as he and with
// err as its wrapped cause. The receiver is a copy, so later changes to the
// original HTTPError do not affect the returned error.
func (he HTTPError) Wrap(err error) error {
	return &HTTPError{
		Code:    he.Code,
		Message: he.Message,
		err:     err,
	}
}

// Unwrap returns the wrapped cause (nil if none), so errors.Is and errors.As
// can inspect it.
func (he *HTTPError) Unwrap() error {
	return he.err
}
// httpError is a status-code-only error; the Err* sentinel values in this file
// are pointers to it.
type httpError struct {
	code int
}

// StatusCode implements HTTPStatusCoder.
func (he httpError) StatusCode() int { return he.code }

// Error returns only the standard status text (e.g. "Not Found"); the numeric
// code is intentionally not included.
func (he httpError) Error() string {
	text := http.StatusText(he.code)
	return text
}

// Wrap returns a new *HTTPError that keeps this status code, uses the standard
// status text as its Message and wraps err as the cause.
func (he httpError) Wrap(err error) error {
	text := http.StatusText(he.code)
	return &HTTPError{Code: he.code, Message: text, err: err}
}
================================================
FILE: httperror_external_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
// run tests as external package to get real feel for API
package echo_test
import (
"encoding/json"
"fmt"
"github.com/labstack/echo/v5"
"net/http"
"net/http/httptest"
)
// ExampleDefaultHTTPErrorHandler shows a handler returning an error type that
// implements both StatusCode() and MarshalJSON(): the response ends up with the
// error's status code and its custom JSON body.
func ExampleDefaultHTTPErrorHandler() {
	e := echo.New()
	e.GET("/api/endpoint", func(c *echo.Context) error {
		return &apiError{
			Code: http.StatusBadRequest,
			Body: map[string]any{"message": "custom error"},
		}
	})

	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/endpoint?err=1", nil))

	fmt.Printf("%d %s", rec.Code, rec.Body.String())
	// Output: 400 {"error":{"message":"custom error"}}
}
// apiError is a sample custom error: StatusCode selects the HTTP status and
// MarshalJSON nests Body under an "error" key.
type apiError struct {
	Code int
	Body any
}

// StatusCode implements echo.HTTPStatusCoder.
func (e *apiError) StatusCode() int { return e.Code }

// MarshalJSON renders the error as {"error": <Body>}.
func (e *apiError) MarshalJSON() ([]byte, error) {
	payload := struct {
		Error any `json:"error"`
	}{Error: e.Body}
	return json.Marshal(payload)
}

// Error returns the standard status text for Code.
func (e *apiError) Error() string { return http.StatusText(e.Code) }
================================================
FILE: httperror_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"errors"
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
// TestHTTPError_StatusCode checks that *HTTPError satisfies HTTPStatusCoder
// through errors.As and reports its Code.
func TestHTTPError_StatusCode(t *testing.T) {
	var err error = &HTTPError{Code: http.StatusBadRequest, Message: "my error message"}

	var coder HTTPStatusCoder
	got := 0
	if found := errors.As(err, &coder); found {
		got = coder.StatusCode()
	}

	assert.Equal(t, http.StatusBadRequest, got)
}
// TestHTTPError_Error checks the Error() text with and without a Message.
func TestHTTPError_Error(t *testing.T) {
	tests := []struct {
		name string
		err  error
		want string
	}{
		{
			name: "ok, without message",
			err:  &HTTPError{Code: http.StatusBadRequest},
			want: "code=400, message=Bad Request",
		},
		{
			name: "ok, with message",
			err:  &HTTPError{Code: http.StatusBadRequest, Message: "my error message"},
			want: "code=400, message=my error message",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.err.Error()
			assert.Equal(t, tt.want, got)
		})
	}
}
// TestHTTPError_WrapUnwrap checks that Wrap copies Code/Message (so later
// changes to the original do not leak) and keeps the cause for Unwrap.
func TestHTTPError_WrapUnwrap(t *testing.T) {
	base := &HTTPError{Code: http.StatusBadRequest, Message: "bad"}
	cause := errors.New("my_error")

	wrapped := base.Wrap(cause).(*HTTPError)

	// mutate the original after wrapping
	base.Code = http.StatusOK
	base.Message = "changed"

	assert.Equal(t, http.StatusBadRequest, wrapped.Code)
	assert.Equal(t, "bad", wrapped.Message)
	assert.Equal(t, errors.New("my_error"), wrapped.Unwrap())
	assert.Equal(t, "code=400, message=bad, err=my_error", wrapped.Error())
}
func TestNewHTTPError(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, "bad")
err2 := &HTTPError{Code: http.StatusBadRequest, Message: "bad"}
assert.Equal(t, err2, err)
}
// TestStatusCode covers direct, sentinel, wrapped, plain and nil errors.
func TestStatusCode(t *testing.T) {
	tests := []struct {
		name string
		in   error
		want int
	}{
		{name: "ok, HTTPError", in: &HTTPError{Code: http.StatusNotFound}, want: http.StatusNotFound},
		{name: "ok, sentinel error", in: ErrNotFound, want: http.StatusNotFound},
		{name: "ok, wrapped HTTPError", in: fmt.Errorf("wrapped: %w", &HTTPError{Code: http.StatusTeapot}), want: http.StatusTeapot},
		{name: "nok, normal error", in: errors.New("error"), want: 0},
		{name: "nok, nil", in: nil, want: 0},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := StatusCode(tt.in)
			assert.Equal(t, tt.want, got)
		})
	}
}
// TestResolveResponseStatus covers every precedence rule of ResolveResponseStatus:
// committed status, error-derived status (explicit code or 500 fallback),
// suggested status, and the implicit 200 default.
func TestResolveResponseStatus(t *testing.T) {
	// someErr carries no status code, so StatusCode(someErr) == 0.
	someErr := errors.New("some error")
	var testCases = []struct {
		name         string
		whenResp     http.ResponseWriter
		whenErr      error
		expectStatus int
		expectResp   bool
	}{
		{
			name:         "nil resp, nil err -> 200",
			whenResp:     nil,
			whenErr:      nil,
			expectStatus: http.StatusOK,
			expectResp:   false,
		},
		{
			name:         "resp suggested status used when no error",
			whenResp:     &Response{Status: http.StatusCreated},
			whenErr:      nil,
			expectStatus: http.StatusCreated,
			expectResp:   true,
		},
		{
			name:         "error overrides suggested status with StatusCode(err)",
			whenResp:     &Response{Status: http.StatusAccepted},
			whenErr:      ErrBadRequest,
			expectStatus: http.StatusBadRequest,
			expectResp:   true,
		},
		{
			// a plain error is required here: ErrInternalServerError has a
			// non-zero StatusCode and would not exercise the fallback path.
			name:         "error overrides suggested status with 500 when StatusCode(err)==0",
			whenResp:     &Response{Status: http.StatusAccepted},
			whenErr:      someErr,
			expectStatus: http.StatusInternalServerError,
			expectResp:   true,
		},
		{
			name:         "error with explicit 500 status code overrides suggested status",
			whenResp:     &Response{Status: http.StatusAccepted},
			whenErr:      ErrInternalServerError,
			expectStatus: http.StatusInternalServerError,
			expectResp:   true,
		},
		{
			name:         "nil resp, error -> 500 when StatusCode(err)==0",
			whenResp:     nil,
			whenErr:      someErr,
			expectStatus: http.StatusInternalServerError,
			expectResp:   false,
		},
		{
			name:         "committed response wins over error",
			whenResp:     &Response{Committed: true, Status: http.StatusNoContent},
			whenErr:      someErr,
			expectStatus: http.StatusNoContent,
			expectResp:   true,
		},
		{
			name:         "committed response with status 0 falls back to 200 (defensive)",
			whenResp:     &Response{Committed: true, Status: 0},
			whenErr:      someErr,
			expectStatus: http.StatusOK,
			expectResp:   true,
		},
		{
			name:         "resp with status 0 and no error -> 200",
			whenResp:     &Response{Status: 0},
			whenErr:      nil,
			expectStatus: http.StatusOK,
			expectResp:   true,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			resp, status := ResolveResponseStatus(tc.whenResp, tc.whenErr)
			assert.Equal(t, tc.expectResp, resp != nil)
			assert.Equal(t, tc.expectStatus, status)
		})
	}
}
================================================
FILE: ip.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"net"
"net/http"
"strings"
)
/**
By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 )
Source: https://echo.labstack.com/guide/ip-address/
IP address plays a fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more.
Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that.
However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application.
In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally.
Otherwise, you might give someone a chance of deceiving you. **A security risk!**
To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure.
In Echo, this can be done by configuring `Echo#IPExtractor` appropriately.
This guide shows you why and how.
> Note: if you don't set `Echo#IPExtractor` explicitly, Echo falls back to legacy behavior, which is not a good choice.
Let's start from two questions to know the right direction:
1. Do you put any HTTP (L7) proxy in front of the application?
- It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway).
2. If yes, what HTTP header do your proxies use to pass client IP to the application?
## Case 1. With no proxy
If you put no proxy in front of the application (e.g.: it faces the internet directly), all you need to (and have to) look at is the IP address from the network layer.
No HTTP header can be trusted, because clients have full control over which headers are set.
In this case, use `echo.ExtractIPDirect()`.
```go
e.IPExtractor = echo.ExtractIPDirect()
```
## Case 2. With proxies using `X-Forwarded-For` header
[`X-Forwarded-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is a popular header
to relay clients' IP addresses.
At each hop on the proxies, they append the request IP address at the end of the header.
Following example diagram illustrates this behavior.
```text
┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
│ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │
│ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │
└──────────┘ └──────────┘ └──────────┘ └──────────┘
Case 1.
XFF: "" "a" "a, b"
~~~~~~
Case 2.
XFF: "x" "x, a" "x, a, b"
~~~~~~~~~
↑ What your app will see
```
In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is
configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructure".
In the above example, if `b` and `c` are trustable, the IP address of the client is `a` in both cases, never `x`.
In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`.
```go
e.IPExtractor = echo.ExtractIPFromXFFHeader()
```
By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address
from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
[RFC4193](https://tools.ietf.org/html/rfc4193)).
To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
E.g.:
```go
e.IPExtractor = echo.ExtractIPFromXFFHeader(
echo.TrustLinkLocal(false),
echo.TrustIPRange(lbIPRange),
)
```
- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
## Case 3. With proxies using `X-Real-IP` header
`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF.
If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`.
```go
e.IPExtractor = echo.ExtractIPFromRealIPHeader()
```
Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address
from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
[RFC4193](https://tools.ietf.org/html/rfc4193)).
To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
> **Never forget** to configure the outermost proxy (i.e., at the edge of your infrastructure) **not to pass through incoming headers**.
> Otherwise there is a chance of fraud, as it is what clients can control.
## About default behavior
By default, Echo looks at all of the following: the first XFF header, the X-Real-IP header and the IP from the network layer.
As you may have noticed after reading this article, this is not good.
The sole reason this is the default is backward compatibility.
## Private IP ranges
See: https://en.wikipedia.org/wiki/Private_network
Private IPv4 address ranges (RFC 1918):
* 10.0.0.0 – 10.255.255.255 (24-bit block)
* 172.16.0.0 – 172.31.255.255 (20-bit block)
* 192.168.0.0 – 192.168.255.255 (16-bit block)
Private IPv6 address ranges:
* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
*/
// ipChecker decides whether an IP address should be trusted, e.g. because it
// belongs to a proxy inside your own infrastructure.
type ipChecker struct {
	trustExtraRanges []*net.IPNet
	trustLoopback    bool
	trustLinkLocal   bool
	trustPrivateNet  bool
}

// TrustOption is config for which IP address to trust
type TrustOption func(*ipChecker)

// TrustLoopback configures whether loopback addresses are trusted (default: true).
func TrustLoopback(v bool) TrustOption {
	return func(c *ipChecker) { c.trustLoopback = v }
}

// TrustLinkLocal configures whether link-local unicast addresses are trusted (default: true).
func TrustLinkLocal(v bool) TrustOption {
	return func(c *ipChecker) { c.trustLinkLocal = v }
}

// TrustPrivateNet configures whether private network addresses are trusted (default: true).
func TrustPrivateNet(v bool) TrustOption {
	return func(c *ipChecker) { c.trustPrivateNet = v }
}

// TrustIPRange adds a trusted IP range (as produced by net.ParseCIDR). The
// option may be given several times; every added range is trusted.
func TrustIPRange(ipRange *net.IPNet) TrustOption {
	return func(c *ipChecker) {
		c.trustExtraRanges = append(c.trustExtraRanges, ipRange)
	}
}

// newIPChecker builds an ipChecker that trusts loopback, link-local and
// private-network addresses, then applies configs in order.
func newIPChecker(configs []TrustOption) *ipChecker {
	c := &ipChecker{}
	c.trustLoopback, c.trustLinkLocal, c.trustPrivateNet = true, true, true
	for _, apply := range configs {
		apply(c)
	}
	return c
}

// trust reports whether ip falls into an enabled address category or into one
// of the extra trusted ranges.
func (c *ipChecker) trust(ip net.IP) bool {
	switch {
	case c.trustLoopback && ip.IsLoopback(),
		c.trustLinkLocal && ip.IsLinkLocalUnicast(),
		c.trustPrivateNet && ip.IsPrivate():
		return true
	}
	for _, r := range c.trustExtraRanges {
		if r.Contains(ip) {
			return true
		}
	}
	return false
}
// IPExtractor is a function to extract IP addr from http.Request.
// Set appropriate one to Echo#IPExtractor.
// See https://echo.labstack.com/guide/ip-address for more details.
type IPExtractor func(*http.Request) string

// ExtractIPDirect returns an IPExtractor that uses only the network-level peer
// address (http.Request.RemoteAddr) and ignores all HTTP headers.
// Use it when the server faces the internet directly (i.e. no proxy in front).
func ExtractIPDirect() IPExtractor {
	return extractIP
}

// extractIP returns the host part of req.RemoteAddr. A RemoteAddr without a
// port is returned as-is when it parses as an IP; anything else yields "".
func extractIP(req *http.Request) string {
	addr := req.RemoteAddr
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	if net.ParseIP(addr) == nil {
		return ""
	}
	return addr
}
// ExtractIPFromRealIPHeader returns an IPExtractor that takes the client IP from
// the X-Real-IP header, but only when the request comes from a trusted peer,
// i.e. when the address extracted from RemoteAddr passes the trust options.
// If the peer is not trusted, the header is missing, or the header value is not
// a valid IP (optionally wrapped in brackets), the peer address is returned.
// Use this if you put a proxy which sets this header in front of the application.
func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
	checker := newIPChecker(options)
	return func(req *http.Request) string {
		directIP := extractIP(req)
		realIP := req.Header.Get(HeaderXRealIP)
		if realIP == "" {
			return directIP
		}
		// Trust is decided by who sent the request (the proxy), not by the value
		// it claims: otherwise any client could spoof e.g. "127.0.0.1".
		if peer := net.ParseIP(directIP); peer == nil || !checker.trust(peer) {
			return directIP
		}
		realIP = strings.TrimPrefix(realIP, "[")
		realIP = strings.TrimSuffix(realIP, "]")
		if net.ParseIP(realIP) == nil {
			// malformed value even from a trusted proxy: do not guess
			return directIP
		}
		return realIP
	}
}
// ExtractIPFromXFFHeader returns an IPExtractor that reads the X-Forwarded-For
// header. The header addresses, followed by the peer address, are walked from
// the nearest (right) to the furthest (left), and the first untrusted address
// is returned. If every address is trusted, the furthest one (XFF[0]) is
// returned. If an unparsable entry is reached first, the peer address is
// returned instead. Use this if you put a proxy which uses this header.
func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor {
	checker := newIPChecker(options)
	return func(req *http.Request) string {
		directIP := extractIP(req)
		values := req.Header[HeaderXForwardedFor]
		if len(values) == 0 {
			return directIP
		}
		// Several XFF header lines are treated as one comma-separated list; the
		// peer address is appended as the nearest hop.
		hops := append(strings.Split(strings.Join(values, ","), ","), directIP)
		for i := len(hops) - 1; i >= 0; i-- {
			hop := strings.TrimSuffix(strings.TrimPrefix(strings.TrimSpace(hops[i]), "["), "]")
			hops[i] = hop
			ip := net.ParseIP(hop)
			switch {
			case ip == nil:
				// one unparsable entry makes the whole chain untrustworthy
				return directIP
			case !checker.trust(ip):
				return ip.String()
			}
		}
		// every hop is trusted: best effort is the furthest one
		return hops[0]
	}
}
================================================
FILE: ip_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"net"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
// mustParseCIDR parses s with net.ParseCIDR and returns the network part,
// panicking on malformed input. Meant for test fixtures only.
func mustParseCIDR(s string) *net.IPNet {
	_, network, err := net.ParseCIDR(s)
	if err == nil {
		return network
	}
	panic(err)
}
// TestIPChecker_TrustOption checks that extra ranges are trusted both with the
// built-in categories disabled and with the defaults, and that an address not
// covered by any option is rejected.
func TestIPChecker_TrustOption(t *testing.T) {
	var testCases = []struct {
		name         string
		whenIP       string
		givenOptions []TrustOption
		expect       bool
	}{
		{
			name: "ip is within trust range, trusts additional IPV6 network",
			givenOptions: []TrustOption{
				TrustLoopback(false),
				TrustLinkLocal(false),
				TrustPrivateNet(false),
				// 2001:db8::/32 is the IPv6 documentation prefix (RFC 3849), not a
				// private range, so only the explicit range below can trust it.
				// CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
				// Address:       2001:0db8:0000:0000:0000:0000:0000:0103
				// Range start:   2001:0db8:0000:0000:0000:0000:0000:0000
				// Range end:     2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
				TrustIPRange(mustParseCIDR("2001:db8::103/48")),
			},
			whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
			expect: true,
		},
		{
			name: "ip is within trust range, default options plus additional IPV6 network",
			givenOptions: []TrustOption{
				TrustIPRange(mustParseCIDR("2001:db8::103/48")),
			},
			whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
			expect: true,
		},
		{
			name: "ip is not trusted when every category is disabled and no range is given",
			givenOptions: []TrustOption{
				TrustLoopback(false),
				TrustLinkLocal(false),
				TrustPrivateNet(false),
			},
			whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
			expect: false,
		},
		{
			name:         "ip is not trusted by default options alone",
			givenOptions: nil,
			whenIP:       "2001:db8::103",
			expect:       false,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			checker := newIPChecker(tc.givenOptions)
			result := checker.trust(net.ParseIP(tc.whenIP))
			assert.Equal(t, tc.expect, result)
		})
	}
}
// TestTrustIPRange checks membership and both boundaries of IPv4/IPv6 ranges
// with all built-in trust categories disabled.
func TestTrustIPRange(t *testing.T) {
	var testCases = []struct {
		name       string
		givenRange string
		whenIP     string
		expect     bool
	}{
		{
			name: "ip is within trust range, IPV6 network range",
			// CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
			// Address:       2001:0db8:0000:0000:0000:0000:0000:0103
			// Range start:   2001:0db8:0000:0000:0000:0000:0000:0000
			// Range end:     2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
			givenRange: "2001:db8::103/48",
			whenIP:     "2001:0db8:0000:0000:0000:0000:0000:0103",
			expect:     true,
		},
		{
			name:       "ip is outside (upper bounds) of trust range, IPV6 network range",
			givenRange: "2001:db8::103/48",
			whenIP:     "2001:0db8:0001:0000:0000:0000:0000:0000",
			expect:     false,
		},
		{
			name:       "ip is outside (lower bounds) of trust range, IPV6 network range",
			givenRange: "2001:db8::103/48",
			whenIP:     "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff",
			expect:     false,
		},
		{
			name: "ip is within trust range, IPV4 network range",
			// CIDR Notation: 8.8.8.0/24
			// Address:       8.8.8.8
			// Range start:   8.8.8.0
			// Range end:     8.8.8.255
			givenRange: "8.8.8.0/24",
			whenIP:     "8.8.8.8",
			expect:     true,
		},
		{
			name:       "ip is within trust range (lower bounds), IPV4 network range",
			givenRange: "8.8.8.0/24",
			whenIP:     "8.8.8.0",
			expect:     true,
		},
		{
			name:       "ip is within trust range (upper bounds), IPV4 network range",
			givenRange: "8.8.8.0/24",
			whenIP:     "8.8.8.255",
			expect:     true,
		},
		{
			name:       "ip is outside (upper bounds) of trust range, IPV4 network range",
			givenRange: "8.8.8.0/24",
			whenIP:     "8.8.9.0",
			expect:     false,
		},
		{
			name:       "ip is outside (lower bounds) of trust range, IPV4 network range",
			givenRange: "8.8.8.0/24",
			whenIP:     "8.8.7.255",
			expect:     false,
		},
		{
			name:       "public ip, trust everything in IPV4 network range",
			givenRange: "0.0.0.0/0",
			whenIP:     "8.8.8.8",
			expect:     true,
		},
		{
			name:       "internal ip, trust everything in IPV4 network range",
			givenRange: "0.0.0.0/0",
			whenIP:     "127.0.10.1",
			expect:     true,
		},
		{
			name:       "public ip, trust everything in IPV6 network range",
			givenRange: "::/0",
			whenIP:     "2a00:1450:4026:805::200e",
			expect:     true,
		},
		{
			name:       "internal ip, trust everything in IPV6 network range",
			givenRange: "::/0",
			whenIP:     "0:0:0:0:0:0:0:1",
			expect:     true,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			cidr := mustParseCIDR(tc.givenRange)
			checker := newIPChecker([]TrustOption{
				TrustLoopback(false),   // disable to avoid interference
				TrustLinkLocal(false),  // disable to avoid interference
				TrustPrivateNet(false), // disable to avoid interference
				TrustIPRange(cidr),
			})
			result := checker.trust(net.ParseIP(tc.whenIP))
			assert.Equal(t, tc.expect, result)
		})
	}
}
// TestTrustPrivateNet checks the RFC 1918 IPv4 blocks and the RFC 4193 fc00::/7
// IPv6 block, including the addresses just outside each boundary.
func TestTrustPrivateNet(t *testing.T) {
	var testCases = []struct {
		name   string
		whenIP string
		expect bool
	}{
		{
			name:   "do not trust public IPv4 address",
			whenIP: "8.8.8.8",
			expect: false,
		},
		{
			name:   "do not trust public IPv6 address",
			whenIP: "2a00:1450:4026:805::200e",
			expect: false,
		},
		{ // Class A: 10.0.0.0 — 10.255.255.255
			name:   "do not trust IPv4 just outside of class A (lower bounds)",
			whenIP: "9.255.255.255",
			expect: false,
		},
		{
			name:   "do not trust IPv4 just outside of class A (upper bounds)",
			whenIP: "11.0.0.0",
			expect: false,
		},
		{
			name:   "trust IPv4 of class A (lower bounds)",
			whenIP: "10.0.0.0",
			expect: true,
		},
		{
			name:   "trust IPv4 of class A (upper bounds)",
			whenIP: "10.255.255.255",
			expect: true,
		},
		{ // Class B: 172.16.0.0 — 172.31.255.255
			name:   "do not trust IPv4 just outside of class B (lower bounds)",
			whenIP: "172.15.255.255",
			expect: false,
		},
		{
			name:   "do not trust IPv4 just outside of class B (upper bounds)",
			whenIP: "172.32.0.0",
			expect: false,
		},
		{
			name:   "trust IPv4 of class B (lower bounds)",
			whenIP: "172.16.0.0",
			expect: true,
		},
		{
			name:   "trust IPv4 of class B (upper bounds)",
			whenIP: "172.31.255.255",
			expect: true,
		},
		{ // Class C: 192.168.0.0 — 192.168.255.255
			name:   "do not trust IPv4 just outside of class C (lower bounds)",
			whenIP: "192.167.255.255",
			expect: false,
		},
		{
			name:   "do not trust IPv4 just outside of class C (upper bounds)",
			whenIP: "192.169.0.0",
			expect: false,
		},
		{
			name:   "trust IPv4 of class C (lower bounds)",
			whenIP: "192.168.0.0",
			expect: true,
		},
		{
			name:   "trust IPv4 of class C (upper bounds)",
			whenIP: "192.168.255.255",
			expect: true,
		},
		{ // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
			// splits the address block in two equally sized halves, fc00::/8 and fd00::/8.
			// https://en.wikipedia.org/wiki/Unique_local_address
			name:   "trust IPv6 private address",
			whenIP: "fdfc:3514:2cb3:4bd5::",
			expect: true,
		},
		{
			name:   "trust IPv6 private address (lower bounds of fc00::/7)",
			whenIP: "fc00::",
			expect: true,
		},
		{
			name:   "trust IPv6 private address (upper bounds of fc00::/7)",
			whenIP: "fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
			expect: true,
		},
		{
			name:   "do not trust IPv6 just outside of fc00::/7 (lower bounds)",
			whenIP: "fbff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
			expect: false,
		},
		{
			// must be a parseable address; an invalid string would only test nil handling
			name:   "do not trust IPv6 just outside of fc00::/7 (upper bounds)",
			whenIP: "fe00::",
			expect: false,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			checker := newIPChecker([]TrustOption{
				TrustLoopback(false),  // disable to avoid interference
				TrustLinkLocal(false), // disable to avoid interference
				TrustPrivateNet(true),
			})
			result := checker.trust(net.ParseIP(tc.whenIP))
			assert.Equal(t, tc.expect, result)
		})
	}
}
// TestTrustLinkLocal checks 169.254.0.0/16 boundaries and IPv6 link-local
// addresses with the other trust categories disabled.
func TestTrustLinkLocal(t *testing.T) {
	checker := newIPChecker([]TrustOption{
		TrustLoopback(false),   // disable to avoid interference
		TrustPrivateNet(false), // disable to avoid interference
		TrustLinkLocal(true),
	})
	cases := []struct {
		name string
		ip   string
		want bool
	}{
		{name: "trust link local IPv4 address (lower bounds)", ip: "169.254.0.0", want: true},
		{name: "trust link local IPv4 address (upper bounds)", ip: "169.254.255.255", want: true},
		{name: "do not trust link local IPv4 address (outside of lower bounds)", ip: "169.253.255.255", want: false},
		{name: "do not trust link local IPv4 address (outside of upper bounds)", ip: "169.255.0.0", want: false},
		{name: "trust link local IPv6 address ", ip: "fe80::1", want: true},
		{name: "do not trust link local IPv6 address ", ip: "fec0::1", want: false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			assert.Equal(t, c.want, checker.trust(net.ParseIP(c.ip)))
		})
	}
}
// TestTrustLoopback checks IPv4/IPv6 loopback addresses with the other trust
// categories disabled.
func TestTrustLoopback(t *testing.T) {
	checker := newIPChecker([]TrustOption{
		TrustLinkLocal(false),  // disable to avoid interference
		TrustPrivateNet(false), // disable to avoid interference
		TrustLoopback(true),
	})
	cases := []struct {
		name string
		ip   string
		want bool
	}{
		{name: "trust IPv4 as localhost", ip: "127.0.0.1", want: true},
		{name: "trust IPv6 as localhost", ip: "::1", want: true},
		{name: "do not trust public ip as localhost", ip: "8.8.8.8", want: false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			assert.Equal(t, c.want, checker.trust(net.ParseIP(c.ip)))
		})
	}
}
func TestExtractIPDirect(t *testing.T) {
var testCases = []struct {
name string
whenRequest http.Request
expectIP string
}{
{
name: "request has no headers, extracts IP from request remote addr",
whenRequest: http.Request{
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "remote addr is IP without port, extracts IP directly",
whenRequest: http.Request{
RemoteAddr: "203.0.113.1",
},
expectIP: "203.0.113.1",
},
{
name: "remote addr is IPv6 without port, extracts IP directly",
whenRequest: http.Request{
RemoteAddr: "2001:db8::1",
},
expectIP: "2001:db8::1",
},
{
name: "remote addr is IPv6 with port",
whenRequest: http.Request{
RemoteAddr: "[2001:db8::1]:8080",
},
expectIP: "2001:db8::1",
},
{
name: "remote addr is invalid, returns empty string",
whenRequest: http.Request{
RemoteAddr: "invalid-ip-format",
},
expectIP: "",
},
{
name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.10"},
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.10"},
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.1",
},
{
name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.10"},
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"127.0.0.1"},
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.1",
},
{
name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.1",
},
{
name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"},
},
RemoteAddr: "127.0.0.1:8080",
},
expectIP: "127.0.0.1",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
extractedIP := ExtractIPDirect()(&tc.whenRequest)
assert.Equal(t, tc.expectIP, extractedIP)
})
}
}
func TestExtractIPFromRealIPHeader(t *testing.T) {
_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
_, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
var testCases = []struct {
whenRequest http.Request
name string
expectIP string
givenTrustOptions []TrustOption
}{
{
name: "request has no headers, extracts IP from request remote addr",
whenRequest: http.Request{
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.1",
},
{
name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted
},
RemoteAddr: "[2001:db8::113:1]:8080",
},
expectIP: "2001:db8::113:1",
},
{
name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
},
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.199"},
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.199",
},
{
name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
},
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"[2001:db8::113:199]"},
},
RemoteAddr: "[2001:db8::113:1]:8080",
},
expectIP: "2001:db8::113:199",
},
{
name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
},
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"203.0.113.199"},
HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything
},
RemoteAddr: "203.0.113.1:8080",
},
expectIP: "203.0.113.199",
},
{
name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
},
whenRequest: http.Request{
Header: http.Header{
HeaderXRealIP: []string{"[2001:db8::113:199]"},
HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything
},
RemoteAddr: "[2001:db8::113:1]:8080",
},
expectIP: "2001:db8::113:199",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest)
assert.Equal(t, tc.expectIP, extractedIP)
})
}
}
// TestExtractIPFromXFFHeader covers ExtractIPFromXFFHeader for IPv4 and IPv6 requests. The cases
// expect the request's RemoteAddr when the X-Forwarded-For header is missing, invalid or comes from
// an untrusted peer. Otherwise they expect an address taken from the X-Forwarded-For chain.
// Every IPv6 case mirrors the IPv4 case before it and carries an "(IPv6)" suffix so subtest names
// stay unique.
func TestExtractIPFromXFFHeader(t *testing.T) {
	_, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
	_, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
	var testCases = []struct {
		whenRequest       http.Request
		name              string
		expectIP          string
		givenTrustOptions []TrustOption
	}{
		{
			name: "request has no headers, extracts IP from request remote addr",
			whenRequest: http.Request{
				RemoteAddr: "203.0.113.1:8080",
			},
			expectIP: "203.0.113.1",
		},
		{
			name: "request has INVALID external XFF header, extract IP from remote addr",
			whenRequest: http.Request{
				Header: http.Header{
					HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid
				},
				RemoteAddr: "127.0.0.1:8080",
			},
			expectIP: "127.0.0.1",
		},
		{
			name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain",
			whenRequest: http.Request{
				Header: http.Header{
					HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"},
				},
				RemoteAddr: "127.0.0.1:8080",
			},
			expectIP: "127.0.0.3",
		},
		{
			name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain (IPv6)",
			whenRequest: http.Request{
				Header: http.Header{
					HeaderXForwardedFor: []string{"[fe80::3], [fe80::2], [fe80::1]"},
				},
				RemoteAddr: "[fe80::1]:8080",
			},
			expectIP: "fe80::3",
		},
		{
			name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr",
			whenRequest: http.Request{
				Header: http.Header{
					HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted
				},
				RemoteAddr: "203.0.113.1:8080",
			},
			expectIP: "203.0.113.1",
		},
		{
			name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr (IPv6)",
			whenRequest: http.Request{
				Header: http.Header{
					HeaderXForwardedFor: []string{"[2001:db8::1]"}, // <-- this is untrusted
				},
				RemoteAddr: "[2001:db8::2]:8080",
			},
			expectIP: "2001:db8::2",
		},
		{
			name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header",
			givenTrustOptions: []TrustOption{
				TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
			},
			// The XFF chain plus RemoteAddr list 5 addresses (client first):
			// 1) 203.0.1.100 (external IP set by 203.0.100.100, which we do not trust - could be spoofed)
			// 2) 203.0.100.100 (outside of our network, but set by 203.0.113.199, which we trust to set correct IPs)
			// 3) 203.0.113.199 (trusted, for example our proxy in some other office)
			// 4) 192.168.1.100 (internal IP, some internal upstream load balancer, e.g. SSL offloading with F5 products)
			// 5) 127.0.0.1 (proxy on localhost, maybe Nginx in front of our Echo instance doing some routing)
			whenRequest: http.Request{
				Header: http.Header{
					HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"},
				},
				RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP
			},
			expectIP: "203.0.100.100", // rightmost untrusted IP: the address reported by our trusted proxy 203.0.113.199
		},
		{
			name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header (IPv6)",
			givenTrustOptions: []TrustOption{
				TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
			},
			// The XFF chain plus RemoteAddr list 5 addresses (client first):
			// 1) 2001:db8:1::1:100 (external IP set by 2001:db8:2::100:100, which we do not trust - could be spoofed)
			// 2) 2001:db8:2::100:100 (outside of our network and of 2001:db8::/64, but set by 2001:db8::113:199, which we trust)
			// 3) 2001:db8::113:199 (trusted, for example our proxy in some other office)
			// 4) fd12:3456:789a:1::1 (internal unique-local IP, some internal upstream load balancer)
			// 5) fe80::1 (link-local proxy, maybe Nginx in front of our Echo instance doing some routing)
			whenRequest: http.Request{
				Header: http.Header{
					HeaderXForwardedFor: []string{"[2001:db8:1::1:100], [2001:db8:2::100:100], [2001:db8::113:199], [fd12:3456:789a:1::1]"},
				},
				RemoteAddr: "[fe80::1]:8080", // IP of proxy upstream of our APP
			},
			expectIP: "2001:db8:2::100:100", // rightmost untrusted IP: the address reported by our trusted proxy 2001:db8::113:199
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest)
			assert.Equal(t, tc.expectIP, extractedIP)
		})
	}
}
================================================
FILE: json.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"encoding/json"
)
// DefaultJSONSerializer implements JSON serialization and deserialization on top of the
// standard library's encoding/json package.
type DefaultJSONSerializer struct{}

// Serialize encodes target as JSON and writes it to the response of c. json.Encoder always
// terminates the output with a newline. A non-empty indent switches to pretty-printed output,
// using indent for each nesting level.
func (d DefaultJSONSerializer) Serialize(c *Context, target any, indent string) error {
	encoder := json.NewEncoder(c.Response())
	if len(indent) > 0 {
		encoder.SetIndent("", indent)
	}
	return encoder.Encode(target)
}

// Deserialize decodes one JSON value from the request body of c into target.
// A decoding failure is reported as ErrBadRequest wrapping the underlying json error.
func (d DefaultJSONSerializer) Deserialize(c *Context, target any) error {
	decoder := json.NewDecoder(c.Request().Body)
	err := decoder.Decode(target)
	if err == nil {
		return nil
	}
	return ErrBadRequest.Wrap(err)
}
================================================
FILE: json_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package echo
import (
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestDefaultJSONCodec_Encode is deliberately simple because there is not much to test: it only
// ensures that DefaultJSONSerializer writes JSON, compact and indented. The heavy lifting is done
// by the Context methods.
func TestDefaultJSONCodec_Encode(t *testing.T) {
	e := New()
	codec := new(DefaultJSONSerializer)
	jon := user{ID: 1, Name: "Jon Snow"}

	// compact output
	req := httptest.NewRequest(http.MethodPost, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	// sanity-check the freshly created context
	assert.Equal(t, e, c.Echo())
	assert.NotNil(t, c.Request())
	assert.NotNil(t, c.Response())

	err := codec.Serialize(c, jon, "")
	if assert.NoError(t, err) {
		assert.Equal(t, userJSON+"\n", rec.Body.String())
	}

	// indented output on a fresh request/recorder pair
	req = httptest.NewRequest(http.MethodPost, "/", nil)
	rec = httptest.NewRecorder()
	c = e.NewContext(req, rec)

	err = codec.Serialize(c, jon, " ")
	if assert.NoError(t, err) {
		assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
	}
}
// TestDefaultJSONCodec_Decode is deliberately simple because there is not much to test. It ensures
// that DefaultJSONSerializer decodes JSON, and that syntax and type errors are reported as
// *HTTPError with status 400. The heavy lifting is done by the Context methods.
func TestDefaultJSONCodec_Decode(t *testing.T) {
	e := New()
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	// Echo
	assert.Equal(t, e, c.Echo())
	// Request
	assert.NotNil(t, c.Request())
	// Response
	assert.NotNil(t, c.Response())
	//--------
	// Default JSON decoder
	//--------
	dec := new(DefaultJSONSerializer)
	var u = user{}
	err := dec.Deserialize(c, &u)
	if assert.NoError(t, err) {
		// testify convention: expected value first, actual value second
		assert.Equal(t, user{ID: 1, Name: "Jon Snow"}, u)
	}

	// malformed JSON -> syntax error wrapped as 400
	var userUnmarshalSyntaxError = user{}
	req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent))
	rec = httptest.NewRecorder()
	c = e.NewContext(req, rec)
	err = dec.Deserialize(c, &userUnmarshalSyntaxError)
	assert.IsType(t, &HTTPError{}, err)
	assert.EqualError(t, err, "code=400, message=Bad Request, err=invalid character 'i' looking for beginning of value")

	// well-formed JSON into a mismatching type -> type error wrapped as 400
	var userUnmarshalTypeError = struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}{}
	req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
	rec = httptest.NewRecorder()
	c = e.NewContext(req, rec)
	err = dec.Deserialize(c, &userUnmarshalTypeError)
	assert.IsType(t, &HTTPError{}, err)
	assert.EqualError(t, err, "code=400, message=Bad Request, err=json: cannot unmarshal number into Go struct field .id of type string")
}
================================================
FILE: middleware/DEVELOPMENT.md
================================================
# Development Guidelines for middlewares
## Best practices:
* Do not use `panic` in middleware creator functions in case of invalid configuration.
* When a middleware function encounters an error while handling a request, return that error instead of calling
`c.Error()` and returning nil, because middlewares earlier in the call chain may have logic for dealing with returned errors.
* Create middleware configuration structs that implement the `MiddlewareConfigurator` interface, so that users can decide
whether an invalid configuration should cause a panic or be returned as an error when the middleware is created.
* When adding `echo.Context` to a function type or field, make it the first parameter so that all functions taking a Context look the same.
================================================
FILE: middleware/basic_auth.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bytes"
"cmp"
"encoding/base64"
"errors"
"strconv"
"strings"
"github.com/labstack/echo/v5"
)
// BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
//
// SECURITY: The Validator function is responsible for securely comparing credentials.
// See BasicAuthValidator documentation for guidance on preventing timing attacks.
type BasicAuthConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. When nil, DefaultSkipper is used.
	Skipper Skipper
	// Validator is a function to validate BasicAuthWithConfig credentials. Note: if the request contains multiple
	// basic auth headers, this function is called once for each of them (at most AllowedCheckLimit headers)
	// until the first valid result is returned.
	// Required.
	Validator BasicAuthValidator
	// Realm is a string to define the realm attribute of BasicAuthWithConfig. It is quoted before being
	// written into the WWW-Authenticate response header.
	// Default value "Restricted".
	Realm string
	// AllowedCheckLimit sets how many basic auth headers are allowed to be checked. This is useful in
	// environments like corporate test environments with application proxies restricting
	// access to the environment with their own auth scheme.
	// Authorization headers that do not use the basic scheme do not count towards this limit.
	// Defaults to 1.
	AllowedCheckLimit uint
}
// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
//
// Return (true, nil) to let the request through to the next handler. Return (false, nil) for
// credentials that are simply wrong. A non-nil error is returned by the middleware, unless a later
// Authorization header validates successfully.
//
// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate
// valid usernames or passwords, validator implementations MUST use constant-time
// comparison for credential checking. Use crypto/subtle.ConstantTimeCompare instead
// of standard string equality (==) or switch statements.
//
// Example of SECURE implementation:
//
//	import "crypto/subtle"
//
//	validator := func(c *echo.Context, username, password string) (bool, error) {
//		// Fetch expected credentials from database/config
//		expectedUser := "admin"
//		expectedPass := "secretpassword"
//
//		// Use constant-time comparison to prevent timing attacks
//		userMatch := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUser)) == 1
//		passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPass)) == 1
//
//		if userMatch && passMatch {
//			return true, nil
//		}
//		return false, nil
//	}
//
// Example of INSECURE implementation (DO NOT USE):
//
//	// VULNERABLE TO TIMING ATTACKS - DO NOT USE
//	validator := func(c *echo.Context, username, password string) (bool, error) {
//		if username == "admin" && password == "secret" { // Timing leak!
//			return true, nil
//		}
//		return false, nil
//	}
type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error)
// basic is the lower-case name of the HTTP Basic authentication scheme. Incoming headers are
// matched against it case-insensitively, and it prefixes the WWW-Authenticate challenge.
const basic = "basic"

// defaultRealm is the realm announced in the WWW-Authenticate challenge when
// BasicAuthConfig.Realm is left empty.
const defaultRealm = "Restricted"
// BasicAuth returns a BasicAuth middleware that validates credentials with fn.
//
// Requests with valid credentials are handed to the next handler. Requests with missing or
// invalid credentials get a "401 - Unauthorized" response. It panics when fn is nil.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
	cfg := BasicAuthConfig{Validator: fn}
	return BasicAuthWithConfig(cfg)
}
// BasicAuthWithConfig returns a BasicAuth middleware built from config.
// It panics on an invalid config (for example a nil Validator). Use BasicAuthConfig.ToMiddleware
// to receive the error instead.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration.
//
// The middleware looks at most at AllowedCheckLimit (default 1) Authorization headers that use the
// basic scheme, i.e. start with "basic" (any case) followed by a space. For each one it:
//   - base64-decodes the credentials,
//   - splits them at the first ':',
//   - calls Validator.
//
// The first successful validation passes the request on to the next handler. Otherwise the last
// error seen is returned; invalid base64 becomes a 400 Bad Request. Without any error, a
// WWW-Authenticate challenge is set and echo.ErrUnauthorized is returned.
func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Validator == nil {
		return nil, errors.New("echo basic-auth middleware requires a validator function")
	}
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}
	realm := defaultRealm
	if config.Realm != "" {
		realm = config.Realm
	}
	realm = strconv.Quote(realm)
	limit := cmp.Or(config.AllowedCheckLimit, 1)

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			var lastError error
			l := len(basic)
			i := uint(0)
			for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
				if i >= limit {
					break
				}
				// The auth-scheme must be followed by a space (RFC 7235/7617). Without the auth[l] check,
				// values like "BasicX<base64>" were accepted and the separator byte silently dropped.
				if len(auth) <= l+1 || auth[l] != ' ' || !strings.EqualFold(auth[:l], basic) {
					continue
				}
				i++

				// Invalid base64 shouldn't be treated as error
				// instead should be treated as invalid client input
				b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
				if errDecode != nil {
					lastError = echo.ErrBadRequest.Wrap(errDecode)
					continue
				}
				idx := bytes.IndexByte(b, ':')
				if idx >= 0 {
					valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:]))
					if errValidate != nil {
						lastError = errValidate
					} else if valid {
						return next(c)
					}
				}
			}

			if lastError != nil {
				return lastError
			}

			// Need to return `401` for browsers to pop-up login box.
			c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
			return echo.ErrUnauthorized
		}
	}, nil
}
================================================
FILE: middleware/basic_auth_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"crypto/subtle"
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestBasicAuth exercises the BasicAuth middleware through a table of configurations and
// Authorization headers. It checks the handler outcome, the returned error and the
// WWW-Authenticate challenge.
func TestBasicAuth(t *testing.T) {
	validatorFunc := func(c *echo.Context, u, p string) (bool, error) {
		// Use constant-time comparison to prevent timing attacks
		userMatch := subtle.ConstantTimeCompare([]byte(u), []byte("joe")) == 1
		passMatch := subtle.ConstantTimeCompare([]byte(p), []byte("secret")) == 1
		if userMatch && passMatch {
			return true, nil
		}
		// Special case for testing error handling
		if u == "error" {
			return false, errors.New(p)
		}
		return false, nil
	}
	defaultConfig := BasicAuthConfig{Validator: validatorFunc}

	var testCases = []struct {
		name         string
		givenConfig  BasicAuthConfig
		whenAuth     []string
		expectHeader string
		expectErr    string
	}{
		{
			name:        "ok",
			givenConfig: defaultConfig,
			whenAuth:    []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
		},
		{
			name:        "ok, multiple",
			givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 2},
			whenAuth: []string{
				"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
				basic + " NOT_BASE64",
				basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
			},
		},
		{
			name:        "nok, multiple, valid out of limit",
			givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 1},
			whenAuth: []string{
				"Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
				basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid_password")),
				// limit only check first and should not check auth below
				basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
			},
			expectHeader: basic + ` realm="Restricted"`,
			expectErr:    "Unauthorized",
		},
		{
			name:         "nok, invalid Authorization header",
			givenConfig:  defaultConfig,
			whenAuth:     []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
			expectHeader: basic + ` realm="Restricted"`,
			expectErr:    "Unauthorized",
		},
		{
			name:        "nok, not base64 Authorization header",
			givenConfig: defaultConfig,
			whenAuth:    []string{strings.ToUpper(basic) + " NOT_BASE64"},
			expectErr:   "code=400, message=Bad Request, err=illegal base64 data at input byte 3",
		},
		{
			name:         "nok, missing Authorization header",
			givenConfig:  defaultConfig,
			expectHeader: basic + ` realm="Restricted"`,
			expectErr:    "Unauthorized",
		},
		{
			name:        "ok, realm",
			givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
			whenAuth:    []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
		},
		{
			name:        "ok, realm, case-insensitive header scheme",
			givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
			whenAuth:    []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
		},
		{
			name:         "nok, realm, invalid Authorization header",
			givenConfig:  BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
			whenAuth:     []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
			expectHeader: basic + ` realm="someRealm"`,
			expectErr:    "Unauthorized",
		},
		{
			name:        "nok, validator func returns an error",
			givenConfig: defaultConfig,
			whenAuth:    []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
			expectErr:   "my_error",
		},
		{
			name: "ok, skipped",
			givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c *echo.Context) bool {
				return true
			}},
			whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mw, err := tc.givenConfig.ToMiddleware()
			assert.NoError(t, err)
			teapot := mw(func(c *echo.Context) error {
				return c.String(http.StatusTeapot, "test")
			})

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			for _, value := range tc.whenAuth {
				req.Header.Add(echo.HeaderAuthorization, value)
			}
			res := httptest.NewRecorder()
			c := echo.New().NewContext(req, res)

			err = teapot(c)

			if tc.expectErr == "" {
				assert.Equal(t, http.StatusTeapot, res.Code)
				assert.NoError(t, err)
			} else {
				// nothing was written, so the recorder keeps its default status
				assert.Equal(t, http.StatusOK, res.Code)
				assert.EqualError(t, err, tc.expectErr)
			}
			if tc.expectHeader != "" {
				assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
			}
		})
	}
}
// TestBasicAuth_panic checks that BasicAuth panics without a validator and succeeds with one.
func TestBasicAuth_panic(t *testing.T) {
	assert.Panics(t, func() {
		assert.NotNil(t, BasicAuth(nil))
	})

	allowAll := func(c *echo.Context, user string, password string) (bool, error) {
		return true, nil
	}
	assert.NotNil(t, BasicAuth(allowAll))
}
// TestBasicAuthWithConfig_panic checks that BasicAuthWithConfig panics on a config without a
// validator and succeeds on a valid config.
func TestBasicAuthWithConfig_panic(t *testing.T) {
	assert.Panics(t, func() {
		assert.NotNil(t, BasicAuthWithConfig(BasicAuthConfig{Validator: nil}))
	})

	allowAll := func(c *echo.Context, user string, password string) (bool, error) {
		return true, nil
	}
	assert.NotNil(t, BasicAuthWithConfig(BasicAuthConfig{Validator: allowAll}))
}
// TestBasicAuthRealm checks how the configured realm ends up in the WWW-Authenticate challenge:
// default fallback, quoting of special characters, and unicode.
func TestBasicAuthRealm(t *testing.T) {
	e := echo.New()
	alwaysReject := func(c *echo.Context, u, p string) (bool, error) {
		return false, nil // Always fail to trigger WWW-Authenticate header
	}

	tests := []struct {
		name         string
		realm        string
		expectedAuth string
	}{
		{
			name:         "Default realm",
			realm:        "Restricted",
			expectedAuth: `basic realm="Restricted"`,
		},
		{
			name:         "Custom realm",
			realm:        "My API",
			expectedAuth: `basic realm="My API"`,
		},
		{
			name:         "Realm with special characters",
			realm:        `Realm with "quotes" and \backslashes`,
			expectedAuth: `basic realm="Realm with \"quotes\" and \\backslashes"`,
		},
		{
			name:         "Empty realm (falls back to default)",
			realm:        "",
			expectedAuth: `basic realm="Restricted"`,
		},
		{
			name:         "Realm with unicode",
			realm:        "测试领域",
			expectedAuth: `basic realm="测试领域"`,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			mw := BasicAuthWithConfig(BasicAuthConfig{Validator: alwaysReject, Realm: tc.realm})
			h := mw(func(c *echo.Context) error {
				return c.String(http.StatusOK, "test")
			})

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			res := httptest.NewRecorder()
			err := h(e.NewContext(req, res))

			assert.Equal(t, echo.ErrUnauthorized, err)
			assert.Equal(t, tc.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
		})
	}
}
================================================
FILE: middleware/body_dump.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bufio"
"bytes"
"errors"
"io"
"net"
"net/http"
"sync"
"github.com/labstack/echo/v5"
)
// BodyDumpConfig defines the config for BodyDump middleware.
type BodyDumpConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. When nil, DefaultSkipper is used.
	Skipper Skipper
	// Handler receives request, response payloads and handler error if there are any.
	// Required.
	Handler BodyDumpHandler
	// MaxRequestBytes limits how much of the request body to dump.
	// If the request body exceeds this limit, only the first MaxRequestBytes
	// are dumped. The handler callback receives truncated data.
	// Note: the bytes beyond the limit are read and discarded, so the next handler in the
	// chain also reads only the truncated body.
	// Default (when zero): 5 * MB (5,242,880 bytes)
	// Set to -1 (any negative value works) to disable limits (not recommended in production).
	MaxRequestBytes int64
	// MaxResponseBytes limits how much of the response body to dump.
	// If the response body exceeds this limit, only the first MaxResponseBytes
	// are dumped. The handler callback receives truncated data; the client still
	// receives the full response.
	// Default (when zero): 5 * MB (5,242,880 bytes)
	// Set to -1 (any negative value works) to disable limits (not recommended in production).
	MaxResponseBytes int64
}
// BodyDumpHandler receives the request and response payload.
// reqBody and resBody are the captured bodies, possibly truncated to
// BodyDumpConfig.MaxRequestBytes / MaxResponseBytes. err is the error returned by the next
// handler in the chain (nil on success).
type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error)
// bodyDumpResponseWriter is installed as the response writer by the BodyDump middleware.
// Body writes go to Writer, which forwards them to the real response and also copies them into
// the dump buffer. The status code, flushing, hijacking and unwrapping are delegated to the
// embedded ResponseWriter.
type bodyDumpResponseWriter struct {
	io.Writer
	http.ResponseWriter
}
// BodyDump returns a BodyDump middleware that hands the captured request and response payloads,
// together with the handler error, to handler.
//
// SECURITY: By default, dumped bodies are capped at 5MB each to prevent memory exhaustion
// attacks. Use BodyDumpWithConfig to customize the limits. To disable them (not recommended in
// production), explicitly set MaxRequestBytes and MaxResponseBytes to -1.
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
	cfg := BodyDumpConfig{Handler: handler}
	return BodyDumpWithConfig(cfg)
}
// BodyDumpWithConfig returns a BodyDump middleware built from config. See: `BodyDump()`.
// It panics on an invalid config (for example a nil Handler).
//
// SECURITY: Zero values of MaxRequestBytes and MaxResponseBytes default to 5MB each to prevent
// DoS attacks via large payloads. Set them explicitly to -1 to disable the limits if your use
// case needs it.
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration.
//
// For every non-skipped request the middleware does the following:
//   - It reads the request body (up to MaxRequestBytes; the rest is discarded) and replaces it with
//     an in-memory copy, so the next handler reads exactly the dumped bytes.
//   - It wraps the response writer so that written bytes are also captured (up to MaxResponseBytes).
//   - After the next handler returns, it restores the original response writer and calls
//     config.Handler with private copies of both captured bodies.
func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Handler == nil {
		return nil, errors.New("echo body-dump middleware requires a handler function")
	}
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}
	if config.MaxRequestBytes == 0 {
		config.MaxRequestBytes = 5 * MB
	}
	if config.MaxResponseBytes == 0 {
		config.MaxResponseBytes = 5 * MB
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			// request part
			reqBuf := bodyDumpBufferPool.Get().(*bytes.Buffer)
			reqBuf.Reset()
			defer bodyDumpBufferPool.Put(reqBuf)

			req := c.Request()
			if req.Body != nil { // manually constructed requests may carry a nil Body
				var bodyReader io.Reader = req.Body
				if config.MaxRequestBytes > 0 {
					bodyReader = io.LimitReader(req.Body, config.MaxRequestBytes)
				}
				// io.Copy reports nil (never io.EOF) on success
				if _, readErr := io.Copy(reqBuf, bodyReader); readErr != nil {
					return readErr
				}
				if config.MaxRequestBytes > 0 {
					// Drain any remaining body data to prevent connection issues
					_, _ = io.Copy(io.Discard, req.Body)
					_ = req.Body.Close()
				}
			}
			// reqBuf goes back to the pool, so the handler and the callback get their own copy
			reqBody := make([]byte, reqBuf.Len())
			copy(reqBody, reqBuf.Bytes())
			req.Body = io.NopCloser(bytes.NewReader(reqBody))

			// response part
			resBuf := bodyDumpBufferPool.Get().(*bytes.Buffer)
			resBuf.Reset()
			defer bodyDumpBufferPool.Put(resBuf)

			originalWriter := c.Response()
			var respWriter io.Writer
			if config.MaxResponseBytes > 0 {
				respWriter = &limitedWriter{
					response: originalWriter,
					dumpBuf:  resBuf,
					limit:    config.MaxResponseBytes,
				}
			} else {
				respWriter = io.MultiWriter(originalWriter, resBuf)
			}
			c.SetResponse(&bodyDumpResponseWriter{
				Writer:         respWriter,
				ResponseWriter: originalWriter,
			})
			// Restore the original writer when we return. Deferred calls run LIFO, so this happens
			// before resBuf is put back into the pool. Writes made after this middleware returns
			// (e.g. an error handler rendering the returned error) must not go into a pooled buffer
			// that another request may already be using.
			defer c.SetResponse(originalWriter)

			err := next(c)

			// Copy the captured response out of the pooled buffer, so the callback may safely keep it
			// after this middleware returns.
			resBody := make([]byte, resBuf.Len())
			copy(resBody, resBuf.Bytes())

			// Callback
			config.Handler(c, reqBody, resBody, err)
			return err
		}
	}, nil
}
// WriteHeader sends the status code straight to the wrapped http.ResponseWriter.
func (w *bodyDumpResponseWriter) WriteHeader(code int) {
	inner := w.ResponseWriter
	inner.WriteHeader(code)
}
// Write routes body bytes through the capturing Writer rather than the embedded ResponseWriter.
func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
	capture := w.Writer
	return capture.Write(b)
}
// Flush flushes the wrapped writer via http.ResponseController. It panics when the underlying
// writer chain does not support flushing; any other flush error is ignored.
func (w *bodyDumpResponseWriter) Flush() {
	rc := http.NewResponseController(w.ResponseWriter)
	// errors.Is(nil, target) is false, so a successful flush never panics
	if errors.Is(rc.Flush(), http.ErrNotSupported) {
		panic(errors.New("response writer flushing is not supported"))
	}
}
// Hijack takes over the connection of the wrapped writer via http.ResponseController, which
// returns an error when the writer chain cannot be hijacked.
func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	rc := http.NewResponseController(w.ResponseWriter)
	return rc.Hijack()
}
// Unwrap exposes the wrapped http.ResponseWriter (used e.g. by http.ResponseController).
func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
	inner := w.ResponseWriter
	return inner
}
// bodyDumpBufferPool recycles the *bytes.Buffer values used to capture request and response
// bodies. Callers Reset a buffer after Get and Put it back when done.
var bodyDumpBufferPool = sync.Pool{
	New: func() any { return &bytes.Buffer{} },
}
// limitedWriter forwards every write unchanged to response, while copying at most limit bytes
// in total into dumpBuf. dumped tracks how many bytes have been captured so far.
type limitedWriter struct {
	response http.ResponseWriter
	dumpBuf  *bytes.Buffer
	dumped   int64
	limit    int64
}

// Write writes b to the real response in full; the client response is never truncated.
// Only the part of the written bytes that still fits under limit is appended to dumpBuf.
// A response write error is returned as-is, and nothing is captured in that case.
func (w *limitedWriter) Write(b []byte) (int, error) {
	n, err := w.response.Write(b)
	if err != nil {
		return n, err
	}

	room := w.limit - w.dumped
	if room <= 0 {
		return n, nil
	}
	take := int64(n)
	if take > room {
		take = room
	}
	w.dumpBuf.Write(b[:take])
	w.dumped += take
	return n, nil
}
================================================
FILE: middleware/body_dump_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestBodyDump checks that the middleware passes both the request and response bodies to the
// dump handler, while the wrapped handler still reads the request and writes its response
// normally.
func TestBodyDump(t *testing.T) {
	e := echo.New()
	hw := "Hello, World!"
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	h := func(c *echo.Context) error {
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return err
		}
		return c.String(http.StatusOK, string(body))
	}

	requestBody := ""
	responseBody := ""
	mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
		requestBody = string(reqBody)
		responseBody = string(resBody)
	}}.ToMiddleware()
	assert.NoError(t, err)

	if assert.NoError(t, mw(h)(c)) {
		// testify convention: expected value first, actual value second
		assert.Equal(t, hw, requestBody)
		assert.Equal(t, hw, responseBody)
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, hw, rec.Body.String())
	}
}
// TestBodyDump_skipper checks that a skipped request goes straight to the next handler: its
// error is returned unchanged and the dump handler is never invoked.
func TestBodyDump_skipper(t *testing.T) {
	dumpCalled := false
	cfg := BodyDumpConfig{
		Skipper: func(c *echo.Context) bool {
			return true
		},
		Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
			dumpCalled = true
		},
	}
	mw, err := cfg.ToMiddleware()
	assert.NoError(t, err)

	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	failing := func(c *echo.Context) error {
		return errors.New("some error")
	}

	assert.EqualError(t, mw(failing)(c), "some error")
	assert.False(t, dumpCalled)
}
// TestBodyDump_fails checks that an error from the next handler is returned unchanged by the
// middleware and nothing is written to the response.
func TestBodyDump_fails(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("Hello, World!"))
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	noopDump := func(c *echo.Context, reqBody, resBody []byte, err error) {}
	mw, err := BodyDumpConfig{Handler: noopDump}.ToMiddleware()
	assert.NoError(t, err)

	failing := func(c *echo.Context) error {
		return errors.New("some error")
	}

	assert.EqualError(t, mw(failing)(c), "some error")
	assert.Equal(t, http.StatusOK, rec.Code)
}
// TestBodyDumpWithConfig_panic checks that BodyDumpWithConfig panics without a Handler and
// succeeds with one.
func TestBodyDumpWithConfig_panic(t *testing.T) {
	noopDump := func(c *echo.Context, reqBody, resBody []byte, err error) {}

	assert.Panics(t, func() {
		assert.NotNil(t, BodyDumpWithConfig(BodyDumpConfig{Skipper: nil, Handler: nil}))
	})
	assert.NotPanics(t, func() {
		assert.NotNil(t, BodyDumpWithConfig(BodyDumpConfig{Handler: noopDump}))
	})
}
// TestBodyDump_panic checks that BodyDump panics without a handler and succeeds with one.
func TestBodyDump_panic(t *testing.T) {
	noopDump := func(c *echo.Context, reqBody, resBody []byte, err error) {}

	assert.Panics(t, func() {
		assert.NotNil(t, BodyDump(nil))
	})
	assert.NotPanics(t, func() {
		BodyDump(noopDump)
	})
}
// TestBodyDumpResponseWriter_CanNotFlush checks that Flush panics when the wrapped writer
// cannot flush.
func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
	w := bodyDumpResponseWriter{
		ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
	}
	assert.PanicsWithError(t, "response writer flushing is not supported", w.Flush)
}
// TestBodyDumpResponseWriter_CanFlush checks that Flush succeeds through a wrapped writer and
// unwraps it exactly once.
func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
	inner := testResponseWriterUnwrapperHijack{
		testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()},
	}
	w := bodyDumpResponseWriter{ResponseWriter: &inner}

	w.Flush()

	assert.Equal(t, 1, inner.unwrapCalled)
}
// TestBodyDumpResponseWriter_CanUnwrap checks that Unwrap returns the wrapped writer itself.
func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
	inner := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
	w := bodyDumpResponseWriter{ResponseWriter: inner}

	assert.Equal(t, inner, w.Unwrap())
}
// TestBodyDumpResponseWriter_CanHijack checks that Hijack reaches a hijack-capable writer. The
// test writer reports success through the sentinel error "can hijack".
func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
	inner := testResponseWriterUnwrapperHijack{
		testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()},
	}
	w := bodyDumpResponseWriter{
		ResponseWriter: &inner, // this RW supports hijacking through unwrapping
	}

	_, _, err := w.Hijack()
	assert.EqualError(t, err, "can hijack")
}
// TestBodyDumpResponseWriter_CanNotHijack checks that Hijack reports "feature not supported"
// when no writer in the wrapped chain can be hijacked.
func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
	trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
	bdrw := bodyDumpResponseWriter{
		ResponseWriter: &trwu, // this RW can be unwrapped, but nothing in its chain supports hijacking
	}
	_, _, err := bdrw.Hijack()
	assert.EqualError(t, err, "feature not supported")
}
// TestBodyDump_ReadError checks that a failure while reading the request body aborts the
// middleware: the read error is returned and the dump callback is not invoked.
func TestBodyDump_ReadError(t *testing.T) {
	// a request body that yields some data and then fails
	body := &failingReadCloser{
		data:     []byte("partial data"),
		failAt:   7,
		failWith: errors.New("connection reset"),
	}
	req := httptest.NewRequest(http.MethodPost, "/", body)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	// This handler should not be reached if body read fails
	echoBody := func(c *echo.Context) error {
		data, _ := io.ReadAll(c.Request().Body)
		return c.String(http.StatusOK, string(data))
	}

	dumpedRequest := ""
	mw := BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {
		dumpedRequest = string(reqBody)
	})

	err := mw(echoBody)(c)

	assert.Error(t, err)
	assert.Contains(t, err.Error(), "connection reset")
	assert.Empty(t, dumpedRequest)
}
// failingReadCloser is a helper type for testing read errors. It yields bytes from data until
// failAt bytes have been read in total, then fails every further Read with failWith. If data
// runs out before failAt is reached, Read reports io.EOF.
type failingReadCloser struct {
	data     []byte
	pos      int
	failAt   int
	failWith error
}

// Read never returns bytes beyond the failAt boundary. failWith is returned together with the
// chunk that reaches the boundary.
func (f *failingReadCloser) Read(p []byte) (n int, err error) {
	if f.pos >= f.failAt {
		return 0, f.failWith
	}
	end := len(f.data)
	if f.failAt < end {
		end = f.failAt
	}
	if f.pos >= end {
		// data exhausted before failAt: honour the io.Reader contract instead of returning (0, nil) forever
		return 0, io.EOF
	}
	n = copy(p, f.data[f.pos:end])
	f.pos += n
	if f.pos >= f.failAt {
		return n, f.failWith
	}
	return n, nil
}

// Close is a no-op.
func (f *failingReadCloser) Close() error {
	return nil
}
func TestBodyDump_RequestWithinLimit(t *testing.T) {
e := echo.New()
requestData := "Hello, World!"
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c *echo.Context) error {
body, _ := io.ReadAll(c.Request().Body)
return c.String(http.StatusOK, string(body))
}
requestBodyDumped := ""
mw, err := BodyDumpConfig{
Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
requestBodyDumped = string(reqBody)
},
MaxRequestBytes: 1 * MB, // 1MB limit
MaxResponseBytes: 1 * MB,
}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
assert.Equal(t, requestData, requestBodyDumped, "Small request should be fully dumped")
assert.Equal(t, requestData, rec.Body.String(), "Handler should receive full request")
}
// TestBodyDump_RequestExceedsLimit checks that a request larger than
// MaxRequestBytes is dumped only up to the limit.
func TestBodyDump_RequestExceedsLimit(t *testing.T) {
	limit := int64(1024) // 1KB limit
	// 2KB of payload against the 1KB dump limit
	payload := strings.Repeat("A", 2*1024)
	wantPrefix := strings.Repeat("A", 1024)

	e := echo.New()
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(payload))
	c := e.NewContext(req, rec)

	var dumped string
	mw, err := BodyDumpConfig{
		MaxRequestBytes:  limit,
		MaxResponseBytes: 1 * MB,
		Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
			dumped = string(reqBody)
		},
	}.ToMiddleware()
	assert.NoError(t, err)

	handler := mw(func(c *echo.Context) error {
		body, _ := io.ReadAll(c.Request().Body)
		return c.String(http.StatusOK, string(body))
	})
	assert.NoError(t, handler(c))

	assert.Equal(t, int(limit), len(dumped), "Dumped request should be truncated to limit")
	assert.Equal(t, wantPrefix, dumped, "Dumped data should match first N bytes")
	// Handler should receive truncated data (what was dumped)
	assert.Equal(t, wantPrefix, rec.Body.String())
}
// TestBodyDump_RequestAtExactLimit checks that a request exactly MaxRequestBytes
// long is dumped in full.
func TestBodyDump_RequestAtExactLimit(t *testing.T) {
	limit := int64(1024)
	// Payload length equals the limit, so nothing should be cut off.
	exactData := strings.Repeat("B", 1024)

	e := echo.New()
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(exactData))
	c := e.NewContext(req, rec)

	var dumped string
	mw, err := BodyDumpConfig{
		MaxRequestBytes:  limit,
		MaxResponseBytes: 1 * MB,
		Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
			dumped = string(reqBody)
		},
	}.ToMiddleware()
	assert.NoError(t, err)

	handler := mw(func(c *echo.Context) error {
		body, _ := io.ReadAll(c.Request().Body)
		return c.String(http.StatusOK, string(body))
	})
	assert.NoError(t, handler(c))

	assert.Equal(t, int(limit), len(dumped), "Exact limit should dump full data")
	assert.Equal(t, exactData, dumped)
}
// TestBodyDump_ResponseWithinLimit checks that a response smaller than the dump
// limit is dumped in full and reaches the client intact.
func TestBodyDump_ResponseWithinLimit(t *testing.T) {
	const responseData = "Response data"
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	var dumpedRes string
	cfg := BodyDumpConfig{
		MaxRequestBytes:  1 * MB,
		MaxResponseBytes: 1 * MB,
		Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
			dumpedRes = string(resBody)
		},
	}
	mw, err := cfg.ToMiddleware()
	assert.NoError(t, err)

	reply := func(c *echo.Context) error {
		return c.String(http.StatusOK, responseData)
	}
	assert.NoError(t, mw(reply)(c))

	assert.Equal(t, responseData, dumpedRes, "Small response should be fully dumped")
	assert.Equal(t, responseData, rec.Body.String(), "Client should receive full response")
}
// TestBodyDump_ResponseExceedsLimit checks that only MaxResponseBytes of a large
// response are dumped while the client still receives every byte.
func TestBodyDump_ResponseExceedsLimit(t *testing.T) {
	limit := int64(1024)                        // 1KB limit
	largeResponse := strings.Repeat("X", 2*1024) // 2KB

	rec := httptest.NewRecorder()
	c := echo.New().NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	var dumpedRes string
	mw, err := BodyDumpConfig{
		MaxRequestBytes:  1 * MB,
		MaxResponseBytes: limit,
		Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
			dumpedRes = string(resBody)
		},
	}.ToMiddleware()
	assert.NoError(t, err)

	handler := mw(func(c *echo.Context) error {
		return c.String(http.StatusOK, largeResponse)
	})
	assert.NoError(t, handler(c))

	// Dump should be truncated
	assert.Equal(t, int(limit), len(dumpedRes), "Dumped response should be truncated to limit")
	assert.Equal(t, strings.Repeat("X", 1024), dumpedRes)
	// Client should still receive full response!
	assert.Equal(t, largeResponse, rec.Body.String(), "Client must receive full response despite dump truncation")
}
// TestBodyDump_ClientGetsFullResponse checks that a response written in several
// chunks is fully delivered even though the dump is capped at 512 bytes.
func TestBodyDump_ClientGetsFullResponse(t *testing.T) {
	// This is critical - even when dump is limited, client gets everything
	largeResponse := strings.Repeat("DATA", 500) // 2KB
	piece := []byte(strings.Repeat("DATA", 125))

	rec := httptest.NewRecorder()
	c := echo.New().NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	// Write response in chunks to test incremental writes
	chunked := func(c *echo.Context) error {
		for written := 0; written < 4; written++ {
			c.Response().Write(piece)
		}
		return nil
	}

	var dumpedRes string
	cfg := BodyDumpConfig{
		MaxRequestBytes:  1 * MB,
		MaxResponseBytes: 512, // Very small limit
		Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
			dumpedRes = string(resBody)
		},
	}
	mw, err := cfg.ToMiddleware()
	assert.NoError(t, err)

	assert.NoError(t, mw(chunked)(c))
	assert.Equal(t, 512, len(dumpedRes), "Dump should be limited")
	assert.Equal(t, largeResponse, rec.Body.String(), "Client must get full response")
}
// TestBodyDump_BothLimitsSimultaneous checks that request and response dumps are
// truncated independently when both exceed their limits.
func TestBodyDump_BothLimitsSimultaneous(t *testing.T) {
	const size = 2 * 1024
	limit := int64(1024)
	largeResponse := strings.Repeat("R", size)

	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(strings.Repeat("Q", size)))
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	h := func(c *echo.Context) error {
		io.ReadAll(c.Request().Body) // Consume request
		return c.String(http.StatusOK, largeResponse)
	}

	var dumpedReq, dumpedRes string
	mw, err := BodyDumpConfig{
		MaxRequestBytes:  limit,
		MaxResponseBytes: limit,
		Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
			dumpedReq, dumpedRes = string(reqBody), string(resBody)
		},
	}.ToMiddleware()
	assert.NoError(t, err)

	assert.NoError(t, mw(h)(c))
	assert.Equal(t, int(limit), len(dumpedReq), "Request dump should be limited")
	assert.Equal(t, int(limit), len(dumpedRes), "Response dump should be limited")
}
// TestBodyDump_DefaultConfig checks that a config with only Handler set still
// dumps a small request body.
func TestBodyDump_DefaultConfig(t *testing.T) {
	const smallData = "test"
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(httptest.NewRequest(http.MethodPost, "/", strings.NewReader(smallData)), rec)

	echoBody := func(c *echo.Context) error {
		body, _ := io.ReadAll(c.Request().Body)
		return c.String(http.StatusOK, string(body))
	}

	var dumped string
	// Only Handler is set; the size limits are left to ToMiddleware's defaults.
	config := BodyDumpConfig{
		Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
			dumped = string(reqBody)
		},
	}
	mw, err := config.ToMiddleware()
	assert.NoError(t, err)

	assert.NoError(t, mw(echoBody)(c))
	assert.Equal(t, smallData, dumped)
}
func TestBodyDump_LargeRequestDosPrevention(t *testing.T) {
e := echo.New()
// Simulate a very large request (10MB) that could cause OOM
largeSize := 10 * 1024 * 1024 // 10MB
largeData := strings.Repeat("M", largeSize)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c *echo.Context) error {
body, _ := io.ReadAll(c.Request().Body)
return c.String(http.StatusOK, string(body))
}
requestBodyDumped := ""
limit := int64(1 * MB) // Only dump 1MB max
mw, err := BodyDumpConfig{
Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
requestBodyDumped = string(reqBody)
},
MaxRequestBytes: limit,
MaxResponseBytes: 1 * MB,
}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
// Verify only 1MB was dumped, not 10MB
assert.Equal(t, int(limit), len(requestBodyDumped), "Should only dump up to limit, preventing DoS")
assert.Less(t, len(requestBodyDumped), largeSize, "Dump should be much smaller than full request")
}
func TestBodyDump_LargeResponseDosPrevention(t *testing.T) {
e := echo.New()
// Simulate a very large response (10MB)
largeSize := 10 * 1024 * 1024 // 10MB
largeResponse := strings.Repeat("R", largeSize)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
h := func(c *echo.Context) error {
return c.String(http.StatusOK, largeResponse)
}
responseBodyDumped := ""
limit := int64(1 * MB) // Only dump 1MB max
mw, err := BodyDumpConfig{
Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
responseBodyDumped = string(resBody)
},
MaxRequestBytes: 1 * MB,
MaxResponseBytes: limit,
}.ToMiddleware()
assert.NoError(t, err)
err = mw(h)(c)
assert.NoError(t, err)
// Verify only 1MB was dumped, not 10MB
assert.Equal(t, int(limit), len(responseBodyDumped), "Should only dump up to limit, preventing DoS")
assert.Less(t, len(responseBodyDumped), largeSize, "Dump should be much smaller than full response")
// Client still gets full response
assert.Equal(t, largeSize, rec.Body.Len(), "Client must receive full response")
}
// BenchmarkBodyDump_WithLimit measures one request/response cycle through
// BodyDump with 1KB request and response bodies and 1MB dump limits.
func BenchmarkBodyDump_WithLimit(b *testing.B) {
	e := echo.New()
	requestData := strings.Repeat("data", 256)  // 1KB
	responseData := strings.Repeat("resp", 256) // 1KB
	h := func(c *echo.Context) error {
		_, _ = io.ReadAll(c.Request().Body) // consume the request so it is dumped
		return c.String(http.StatusOK, responseData)
	}
	mw, err := BodyDumpConfig{
		Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
			// Simulate logging
			_ = len(reqBody) + len(resBody)
		},
		MaxRequestBytes:  1 * MB,
		MaxResponseBytes: 1 * MB,
	}.ToMiddleware()
	if err != nil {
		b.Fatal(err)
	}
	// Build the handler chain once; wrapping is a one-time cost, not per request.
	handler := mw(h)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
		rec := httptest.NewRecorder()
		c := e.NewContext(req, rec)
		if err := handler(c); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkBodyDump_BufferPooling reports allocations per request through
// BodyDump with a 1KB request, a short response and a no-op dump handler.
func BenchmarkBodyDump_BufferPooling(b *testing.B) {
	e := echo.New()
	requestData := strings.Repeat("x", 1024)
	responseData := "response"
	h := func(c *echo.Context) error {
		_, _ = io.ReadAll(c.Request().Body) // consume the request so it is dumped
		return c.String(http.StatusOK, responseData)
	}
	mw, err := BodyDumpConfig{
		Handler:          func(c *echo.Context, reqBody, resBody []byte, err error) {},
		MaxRequestBytes:  1 * MB,
		MaxResponseBytes: 1 * MB,
	}.ToMiddleware()
	if err != nil {
		b.Fatal(err)
	}
	// Build the handler chain once so its allocations are not attributed to each request.
	handler := mw(h)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
		rec := httptest.NewRecorder()
		c := e.NewContext(req, rec)
		if err := handler(c); err != nil {
			b.Fatal(err)
		}
	}
}
================================================
FILE: middleware/body_limit.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"io"
"net/http"
"sync"
"github.com/labstack/echo/v5"
)
// BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
type BodyLimitConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. A nil Skipper is replaced by DefaultSkipper.
	Skipper Skipper

	// LimitBytes is the maximum allowed size of a request body, in bytes.
	// A request whose Content-Length, or whose body as actually read, exceeds
	// it is rejected with "413 - Request Entity Too Large".
	LimitBytes int64
}
// limitedReader wraps a request body and counts the bytes read through it,
// reporting ErrStatusRequestEntityTooLarge once LimitBytes is exceeded.
// Instances are pooled by the middleware and re-armed with Reset per request.
type limitedReader struct {
	BodyLimitConfig
	reader io.ReadCloser // the wrapped request body
	read   int64         // bytes read from reader since the last Reset
}
// BodyLimit returns a BodyLimit middleware.
//
// The middleware rejects request bodies larger than limitBytes with a
// "413 - Request Entity Too Large" error. The size is checked both against the
// `Content-Length` header and against the bytes actually read from the body, so
// a missing or understated header does not bypass the limit.
func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
	cfg := BodyLimitConfig{LimitBytes: limitBytes}
	return BodyLimitWithConfig(cfg)
}
// BodyLimitWithConfig returns a BodyLimit middleware built from config.
//
// Bodies larger than config.LimitBytes, judged by the `Content-Length` header
// and by the bytes actually read, produce a "413 - Request Entity Too Large"
// error. Construction goes through toMiddlewareOrPanic; use
// BodyLimitConfig.ToMiddleware to receive configuration errors as values.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration.
//
// The current implementation never returns a non-nil error. Each request is
// first checked against its declared Content-Length. Its body is then wrapped
// in a pooled limitedReader that fails once more than LimitBytes are read.
func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}

	// Reuse limitedReader wrappers across requests instead of allocating per request.
	readers := sync.Pool{
		New: func() any { return &limitedReader{BodyLimitConfig: config} },
	}

	mw := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}
			req := c.Request()

			// Fast path: the declared size alone is already too large.
			if req.ContentLength > config.LimitBytes {
				return echo.ErrStatusRequestEntityTooLarge
			}

			// Slow path: count bytes as the handler reads them.
			lr, ok := readers.Get().(*limitedReader)
			if !ok {
				return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object")
			}
			lr.Reset(req.Body)
			// NOTE(review): lr goes back to the pool while req.Body still points at it;
			// confirm nothing reads req.Body after next returns.
			defer readers.Put(lr)
			req.Body = lr
			return next(c)
		}
	}
	return mw, nil
}
// Read reads from the wrapped body and adds the bytes to the running total.
// Once the total exceeds LimitBytes it reports ErrStatusRequestEntityTooLarge;
// the bytes read by that call are still returned in n.
func (r *limitedReader) Read(b []byte) (n int, err error) {
	n, err = r.reader.Read(b)
	r.read += int64(n)
	if r.read <= r.LimitBytes {
		return n, err
	}
	return n, echo.ErrStatusRequestEntityTooLarge
}
// Close closes the wrapped request body.
func (r *limitedReader) Close() error {
	body := r.reader
	return body.Close()
}
// Reset points the reader at a new body and clears the byte counter, so a pooled
// instance can be reused for another request.
func (r *limitedReader) Reset(reader io.ReadCloser) {
	r.read, r.reader = 0, reader
}
================================================
FILE: middleware/body_limit_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestBodyLimitConfig_ToMiddleware covers both checks performed by the body
// limit middleware: the declared Content-Length and the bytes actually read.
// Every case uses a fresh request and recorder so that no case observes a body
// already consumed by an earlier one.
func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
	e := echo.New()
	hw := []byte("Hello, World!")
	h := func(c *echo.Context) error {
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return err
		}
		return c.String(http.StatusOK, string(body))
	}
	assertTooLarge := func(err error) {
		t.Helper()
		var he echo.HTTPStatusCoder
		if assert.ErrorAs(t, err, &he) {
			assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
		}
	}

	// Based on content length (within limit)
	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
	assert.NoError(t, err)
	err = mw(h)(c)
	if assert.NoError(t, err) {
		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Equal(t, hw, rec.Body.Bytes())
	}

	// Based on content length (overlimit): Content-Length 13 > 2 is rejected before any read
	req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	rec = httptest.NewRecorder()
	c = e.NewContext(req, rec)
	mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
	assert.NoError(t, err)
	assertTooLarge(mw(h)(c))

	// Based on content read (within limit)
	req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	req.ContentLength = -1 // unknown size: only the read-based check applies
	rec = httptest.NewRecorder()
	c = e.NewContext(req, rec)
	mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
	assert.NoError(t, err)
	err = mw(h)(c)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "Hello, World!", rec.Body.String())

	// Based on content read (overlimit)
	req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
	req.ContentLength = -1 // unknown size: only the read-based check applies
	rec = httptest.NewRecorder()
	c = e.NewContext(req, rec)
	mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
	assert.NoError(t, err)
	assertTooLarge(mw(h)(c))
}
// TestBodyLimitReader exercises limitedReader directly: it must fail once the
// limit is exceeded, and Reset must re-arm it with a cleared byte counter.
func TestBodyLimitReader(t *testing.T) {
	hw := []byte("Hello, World!")
	config := BodyLimitConfig{
		Skipper:    DefaultSkipper,
		LimitBytes: 2,
	}
	reader := &limitedReader{
		BodyLimitConfig: config,
		reader:          io.NopCloser(bytes.NewReader(hw)),
	}
	assertTooLarge := func(err error) {
		t.Helper()
		var he echo.HTTPStatusCoder
		if assert.ErrorAs(t, err, &he) {
			assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
		}
	}

	// read all should return ErrStatusRequestEntityTooLarge
	_, err := io.ReadAll(reader)
	assertTooLarge(err)

	// reset reader and read two bytes must succeed
	bt := make([]byte, 2)
	reader.Reset(io.NopCloser(bytes.NewReader(hw)))
	n, err := reader.Read(bt)
	assert.Equal(t, 2, n)
	assert.NoError(t, err)
	assert.Equal(t, []byte("He"), bt)

	// one more byte pushes the running total past the limit
	n, err = reader.Read(bt[:1])
	assert.Equal(t, 1, n)
	assertTooLarge(err)
}
// TestBodyLimit_skipper checks that a skipped request is not limited even when
// its body is far larger than LimitBytes.
func TestBodyLimit_skipper(t *testing.T) {
	e := echo.New()
	echoBody := func(c *echo.Context) error {
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return err
		}
		return c.String(http.StatusOK, string(body))
	}

	cfg := BodyLimitConfig{
		LimitBytes: 2,
		Skipper:    func(c *echo.Context) bool { return true },
	}
	mw, err := cfg.ToMiddleware()
	assert.NoError(t, err)

	payload := []byte("Hello, World!")
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(payload)), rec)

	assert.NoError(t, mw(echoBody)(c))
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, payload, rec.Body.Bytes())
}
// TestBodyLimitWithConfig checks that a body within the configured limit passes
// through BodyLimitWithConfig unchanged.
func TestBodyLimitWithConfig(t *testing.T) {
	payload := []byte("Hello, World!")
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(payload)), rec)

	echoBody := func(c *echo.Context) error {
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return err
		}
		return c.String(http.StatusOK, string(body))
	}

	limited := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})(echoBody)
	assert.NoError(t, limited(c))
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, payload, rec.Body.Bytes())
}
// TestBodyLimit checks that a body within the limit passes through BodyLimit
// unchanged.
func TestBodyLimit(t *testing.T) {
	payload := []byte("Hello, World!")
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(payload)), rec)

	echoBody := func(c *echo.Context) error {
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return err
		}
		return c.String(http.StatusOK, string(body))
	}

	limited := BodyLimit(2 * MB)(echoBody)
	assert.NoError(t, limited(c))
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, payload, rec.Body.Bytes())
}
================================================
FILE: middleware/compress.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bufio"
"bytes"
"compress/gzip"
"errors"
"io"
"net"
"net/http"
"strings"
"sync"
"github.com/labstack/echo/v5"
)
// gzipScheme is the content-coding token for gzip, as used in the
// Accept-Encoding and Content-Encoding headers.
const gzipScheme = "gzip"
// GzipConfig defines the config for Gzip middleware.
type GzipConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. A nil Skipper is replaced by DefaultSkipper.
	Skipper Skipper

	// Level is the gzip compression level. It must lie between -2
	// (gzip.HuffmanOnly) and 9 (gzip.BestCompression); other values make
	// ToMiddleware return an error.
	// Optional. 0 is treated as -1 (gzip.DefaultCompression), so
	// gzip.NoCompression cannot be selected through this field.
	Level int

	// MinLength is the number of response bytes that must be written before gzip
	// compression is applied. Shorter responses are sent uncompressed unless the
	// handler flushes first.
	// Optional. Default value 0; negative values are treated as 0.
	//
	// Most of the time you will not need to change the default. Compressing
	// a short response might increase the transmitted data because of the
	// gzip format overhead. Compressing the response will also consume CPU
	// and time on the server and the client (for decompressing). Depending on
	// your use case such a threshold might be useful.
	//
	// See also:
	// https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
	MinLength int
}
// gzipResponseWriter wraps the response so that body bytes are buffered until
// minLength is reached and then routed through a gzip writer.
type gzipResponseWriter struct {
	io.Writer           // gzip writer that compressed output goes through
	http.ResponseWriter // the original, uncompressed response writer

	wroteHeader       bool          // WriteHeader was called; status is held in code
	wroteBody         bool          // the response body has been started
	minLength         int           // bytes to accumulate before committing to gzip
	minLengthExceeded bool          // gzip committed; body goes straight to Writer
	buffer            *bytes.Buffer // holds output while below minLength
	code              int           // status code deferred by WriteHeader
}
// Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme, configured with a zero-value GzipConfig (default level, no minimum length).
func Gzip() echo.MiddlewareFunc {
	var defaults GzipConfig
	return GzipWithConfig(defaults)
}
// GzipWithConfig returns a middleware which compresses HTTP response using gzip
// compression scheme. It is built via toMiddlewareOrPanic, so an invalid config
// (for example an out-of-range Level) panics. Use GzipConfig.ToMiddleware to get
// the error instead.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration.
//
// Config handling:
//   - a nil Skipper becomes DefaultSkipper;
//   - a Level outside [-2, 9] (gzip.HuffmanOnly .. gzip.BestCompression) yields an
//     "invalid gzip level" error;
//   - Level 0 is replaced with -1 (gzip.DefaultCompression);
//   - a negative MinLength is clamped to 0.
//
// For each request, the middleware always adds "Vary: Accept-Encoding". When the
// request's Accept-Encoding header contains "gzip", the response is swapped for a
// gzipResponseWriter backed by pooled gzip writers and buffers. A deferred
// function then decides how to finish the response after the handler returns.
func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}
	if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
		return nil, errors.New("invalid gzip level")
	}
	if config.Level == 0 {
		config.Level = -1
	}
	if config.MinLength < 0 {
		config.MinLength = 0
	}
	// Pools shared by every request served through this middleware instance.
	pool := gzipCompressPool(config)
	bpool := bufferPool()
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}
			res := c.Response()
			// Caches must key on Accept-Encoding whether or not this response ends up compressed.
			res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
			// NOTE(review): plain substring match, so a value such as "gzip;q=0" is also treated as accepting gzip.
			if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
				i := pool.Get()
				// The pool yields an error value instead of a writer when gzip.NewWriterLevel fails.
				w, ok := i.(*gzip.Writer)
				if !ok {
					return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object")
				}
				rw := res
				w.Reset(rw)
				buf := bpool.Get().(*bytes.Buffer)
				buf.Reset()
				grw := &gzipResponseWriter{
					Writer:         w,
					ResponseWriter: rw,
					minLength:      config.MinLength,
					buffer:         buf,
				}
				c.SetResponse(grw)
				defer func() {
					// There are different reasons for cases when we have not yet written response to the client and now need to do so.
					// a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
					// b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
					if !grw.wroteBody {
						if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
							res.Header().Del(echo.HeaderContentEncoding)
						}
						if grw.wroteHeader {
							rw.WriteHeader(grw.code)
						}
						// We have to reset response to it's pristine state when
						// nothing is written to body or error is returned.
						// See issue #424, #407.
						c.SetResponse(rw)
						// Point the gzip writer at io.Discard so the Close below sends nothing to the client.
						w.Reset(io.Discard)
					} else if !grw.minLengthExceeded {
						// Write uncompressed response
						c.SetResponse(rw)
						if grw.wroteHeader {
							grw.ResponseWriter.WriteHeader(grw.code)
						}
						_, _ = grw.buffer.WriteTo(rw)
						// The body went out uncompressed; discard the gzip stream Close would emit.
						w.Reset(io.Discard)
					}
					// When compression was committed this writes the gzip footer to the client.
					_ = w.Close()
					bpool.Put(buf)
					pool.Put(w)
				}()
			}
			return next(c)
		}
	}, nil
}
// WriteHeader records the status code instead of sending it. Whether the
// response will carry Content-Encoding: gzip is not known until enough body has
// been written. Content-Length is dropped because compression changes the length.
func (w *gzipResponseWriter) WriteHeader(code int) {
	w.code, w.wroteHeader = code, true
	w.Header().Del(echo.HeaderContentLength) // Issue #444
}
// Write implements io.Writer for the wrapped response.
//
// Until minLength bytes have been written, output is held in w.buffer so that
// short responses can still be sent uncompressed. On the write that reaches the
// threshold, the Content-Encoding header and any deferred status code are sent.
// The buffered bytes are then pushed into the gzip writer, and b is compressed
// directly without being copied into the pooled buffer. The returned count
// always refers to b alone, as the io.Writer contract requires (n <= len(b)).
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
	if w.Header().Get(echo.HeaderContentType) == "" {
		w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
	}
	w.wroteBody = true
	if w.minLengthExceeded {
		return w.Writer.Write(b)
	}
	if w.buffer.Len()+len(b) < w.minLength {
		// Still below the threshold: keep buffering.
		return w.buffer.Write(b)
	}
	// The minimum length is reached, add Content-Encoding header and write the header
	w.minLengthExceeded = true
	w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
	if w.wroteHeader {
		w.ResponseWriter.WriteHeader(w.code)
	}
	if w.buffer.Len() > 0 {
		if _, err := w.Writer.Write(w.buffer.Bytes()); err != nil {
			return 0, err
		}
		w.buffer.Reset()
	}
	return w.Writer.Write(b)
}
// Flush sends pending data to the client. More data may follow, so a flush
// commits the response to gzip even if minLength has not been reached yet.
func (w *gzipResponseWriter) Flush() {
	if !w.minLengthExceeded {
		// Enforce compression because we will not know how much more data will come
		w.minLengthExceeded = true
		// The gzip stream is started and flushed below, so it must be finished
		// (not discarded) when the handler returns, even if Write was never called.
		// Otherwise the client would receive a gzip stream without its footer.
		w.wroteBody = true
		w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
		if w.wroteHeader {
			w.ResponseWriter.WriteHeader(w.code)
		}
		_, _ = w.Writer.Write(w.buffer.Bytes()) // best effort: Flush cannot report errors
		w.buffer.Reset()
	}
	if gw, ok := w.Writer.(*gzip.Writer); ok {
		_ = gw.Flush()
	}
	_ = http.NewResponseController(w.ResponseWriter).Flush()
}
// Hijack delegates to the wrapped writer, unwrapping it as needed via
// http.ResponseController. It returns an error if no layer supports hijacking.
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	rc := http.NewResponseController(w.ResponseWriter)
	return rc.Hijack()
}
// Unwrap exposes the original response writer so that http.ResponseController
// can reach optional interfaces implemented underneath.
func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
	inner := w.ResponseWriter
	return inner
}
// Push forwards an HTTP/2 server push to the wrapped writer. It returns
// http.ErrNotSupported when that writer does not implement http.Pusher.
func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
	pusher, ok := w.ResponseWriter.(http.Pusher)
	if !ok {
		return http.ErrNotSupported
	}
	return pusher.Push(target, opts)
}
// gzipCompressPool returns a pool of *gzip.Writer at config.Level, initially
// targeting io.Discard. If a writer cannot be created, the pool yields the error
// value instead, so callers must type-check the result of Get.
func gzipCompressPool(config GzipConfig) sync.Pool {
	level := config.Level
	return sync.Pool{
		New: func() any {
			gw, err := gzip.NewWriterLevel(io.Discard, level)
			if err == nil {
				return gw
			}
			return err
		},
	}
}
// bufferPool returns a pool that hands out empty *bytes.Buffer values.
func bufferPool() sync.Pool {
	return sync.Pool{
		New: func() any { return new(bytes.Buffer) },
	}
}
================================================
FILE: middleware/compress_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
	"bytes"
	"compress/gzip"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"strconv"
	"testing"
	"time"

	"github.com/labstack/echo/v5"
	"github.com/stretchr/testify/assert"
)
func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
// Skip if no Accept-Encoding header
h := Gzip()(func(c *echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := h(c)
assert.NoError(t, err)
assert.Equal(t, "test", rec.Body.String())
}
func TestMustGzipWithConfig_panics(t *testing.T) {
assert.Panics(t, func() {
GzipWithConfig(GzipConfig{Level: 999})
})
}
// TestGzip_AcceptEncodingHeader checks that a gzip-accepting client receives a
// compressed body with Content-Encoding and a sniffed Content-Type.
func TestGzip_AcceptEncodingHeader(t *testing.T) {
	h := Gzip()(func(c *echo.Context) error {
		c.Response().Write([]byte("test")) // For Content-Type sniffing
		return nil
	})
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	assert.NoError(t, h(e.NewContext(req, rec)))

	hdr := rec.Header()
	assert.Equal(t, gzipScheme, hdr.Get(echo.HeaderContentEncoding))
	assert.Contains(t, hdr.Get(echo.HeaderContentType), echo.MIMETextPlain)

	gr, err := gzip.NewReader(rec.Body)
	assert.NoError(t, err)
	decoded := new(bytes.Buffer)
	defer gr.Close()
	decoded.ReadFrom(gr)
	assert.Equal(t, "test", decoded.String())
}
// TestGzip_chunked checks that flushed chunks of a streaming response form a
// single valid gzip stream once the handler returns.
func TestGzip_chunked(t *testing.T) {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	chunkChan := make(chan struct{})
	waitChan := make(chan struct{})
	h := Gzip()(func(c *echo.Context) error {
		rc := http.NewResponseController(c.Response())
		c.Response().Header().Set("Content-Type", "text/event-stream")
		c.Response().Header().Set("Transfer-Encoding", "chunked")
		// Write and flush the first part of the data
		c.Response().Write([]byte("first\n"))
		rc.Flush()
		chunkChan <- struct{}{}
		<-waitChan
		// Write and flush the second part of the data
		c.Response().Write([]byte("second\n"))
		rc.Flush()
		chunkChan <- struct{}{}
		<-waitChan
		// Write the final part of the data and return
		c.Response().Write([]byte("third"))
		chunkChan <- struct{}{}
		return nil
	})
	// Buffered so the handler goroutine never blocks on reporting its result.
	done := make(chan error, 1)
	go func() {
		done <- h(c)
	}()
	<-chunkChan // wait for first write
	waitChan <- struct{}{}
	<-chunkChan // wait for second write
	waitChan <- struct{}{}
	<-chunkChan // wait for final write in handler
	// Receiving the handler's result guarantees the middleware has closed the gzip
	// stream, so no sleep is needed before inspecting the recorder. The error is
	// asserted on the test goroutine rather than after the test may have finished.
	select {
	case err := <-done:
		assert.NoError(t, err)
	case <-time.After(5 * time.Second):
		t.Fatal("handler did not return")
	}
	assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
	r, err := gzip.NewReader(rec.Body)
	assert.NoError(t, err)
	buf := new(bytes.Buffer)
	buf.ReadFrom(r)
	assert.Equal(t, "first\nsecond\nthird", buf.String())
}
// TestGzip_NoContent checks that a body-less 204 response neither advertises
// gzip nor gets a sniffed content type.
func TestGzip_NoContent(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)
	h := Gzip()(func(c *echo.Context) error {
		return c.NoContent(http.StatusNoContent)
	})
	if !assert.NoError(t, h(c)) {
		return
	}
	assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
	assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
	assert.Equal(t, 0, len(rec.Body.Bytes()))
}
// TestGzip_Empty checks that an explicitly written empty body still produces a
// valid (empty) gzip stream.
func TestGzip_Empty(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)
	h := Gzip()(func(c *echo.Context) error {
		return c.String(http.StatusOK, "")
	})
	if !assert.NoError(t, h(c)) {
		return
	}
	assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
	assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType))
	gr, err := gzip.NewReader(rec.Body)
	if !assert.NoError(t, err) {
		return
	}
	var out bytes.Buffer
	out.ReadFrom(gr)
	assert.Equal(t, "", out.String())
}
// TestGzip_ErrorReturned checks that an error response from the handler is sent
// uncompressed.
func TestGzip_ErrorReturned(t *testing.T) {
	e := echo.New()
	e.Use(Gzip())
	e.GET("/", func(c *echo.Context) error { return echo.ErrNotFound })

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	e.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusNotFound, rec.Code)
	assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
// TestGzipWithConfig_invalidLevel checks that ToMiddleware rejects a level above 9.
func TestGzipWithConfig_invalidLevel(t *testing.T) {
	cfg := GzipConfig{Level: 12}
	mw, err := cfg.ToMiddleware()
	assert.EqualError(t, err, "invalid gzip level")
	assert.Nil(t, mw)
}
// Issue #806
// TestGzipWithStatic checks that static files served through Gzip decompress to
// their exact on-disk contents.
func TestGzipWithStatic(t *testing.T) {
	e := echo.New()
	e.Filesystem = os.DirFS("../")
	e.Use(Gzip())
	e.Static("/test", "_fixture/images")
	req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)
	assert.Equal(t, http.StatusOK, rec.Code)
	// Data is written out in chunks when Content-Length == "", so only
	// validate the content length if it is set. The header is a decimal string,
	// so compare it against the string form of the body length.
	if cl := rec.Header().Get("Content-Length"); cl != "" {
		assert.Equal(t, cl, strconv.Itoa(rec.Body.Len()))
	}
	r, err := gzip.NewReader(rec.Body)
	if assert.NoError(t, err) {
		defer r.Close()
		want, err := os.ReadFile("../_fixture/images/walle.png")
		if assert.NoError(t, err) {
			buf := new(bytes.Buffer)
			buf.ReadFrom(r)
			assert.Equal(t, want, buf.Bytes())
		}
	}
}
// TestGzipWithMinLength checks that a body longer than MinLength is compressed.
func TestGzipWithMinLength(t *testing.T) {
	const payload = "foobarfoobar" // 12 bytes, above the 10 byte threshold
	e := echo.New()
	// Minimal response length
	e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
	e.GET("/", func(c *echo.Context) error {
		c.Response().Write([]byte(payload))
		return nil
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)

	assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
	if gr, err := gzip.NewReader(rec.Body); assert.NoError(t, err) {
		decoded := new(bytes.Buffer)
		defer gr.Close()
		decoded.ReadFrom(gr)
		assert.Equal(t, payload, decoded.String())
	}
}
// TestGzipWithMinLengthTooShort checks that a body shorter than MinLength is
// sent uncompressed.
func TestGzipWithMinLengthTooShort(t *testing.T) {
	e := echo.New()
	// Minimal response length
	e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
	e.GET("/", func(c *echo.Context) error {
		c.Response().Write([]byte("test"))
		return nil
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	e.ServeHTTP(rec, req)

	assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
	assert.Contains(t, rec.Body.String(), "test")
}
// TestGzipWithResponseWithoutBody checks that a redirect keeps its status code
// and is not marked as gzip-encoded.
func TestGzipWithResponseWithoutBody(t *testing.T) {
	e := echo.New()
	e.Use(Gzip())
	e.GET("/", func(c *echo.Context) error {
		return c.Redirect(http.StatusMovedPermanently, "http://localhost")
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	e.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusMovedPermanently, rec.Code)
	assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
}
// TestGzipWithMinLengthChunked checks that flushing before MinLength is reached
// forces compression and that each flushed chunk is readable immediately.
func TestGzipWithMinLengthChunked(t *testing.T) {
	e := echo.New()
	// Gzip chunked
	chunk := make([]byte, 5)
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	var zr *gzip.Reader
	c := e.NewContext(req, rec)

	next := func(c *echo.Context) error {
		res := c.Response()
		rc := http.NewResponseController(res)
		res.Header().Set("Content-Type", "text/event-stream")
		res.Header().Set("Transfer-Encoding", "chunked")

		// Write and flush the first part of the data
		res.Write([]byte("test\n"))
		rc.Flush()

		// Read the first part of the data
		assert.True(t, rec.Flushed)
		assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
		var err error
		zr, err = gzip.NewReader(rec.Body)
		assert.NoError(t, err)
		_, err = io.ReadFull(zr, chunk)
		assert.NoError(t, err)
		assert.Equal(t, "test\n", string(chunk))

		// Write and flush the second part of the data
		res.Write([]byte("test\n"))
		rc.Flush()
		_, err = io.ReadFull(zr, chunk)
		assert.NoError(t, err)
		assert.Equal(t, "test\n", string(chunk))

		// Write the final part of the data and return
		res.Write([]byte("test"))
		return nil
	}

	assert.NoError(t, GzipWithConfig(GzipConfig{MinLength: 10})(next)(c))
	assert.NotNil(t, zr)
	rest := new(bytes.Buffer)
	rest.ReadFrom(zr)
	assert.Equal(t, "test", rest.String())
	zr.Close()
}
// TestGzipWithMinLengthNoContent checks that a 204 response stays uncompressed
// and untyped when a MinLength threshold is configured.
func TestGzipWithMinLengthNoContent(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)
	h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c *echo.Context) error {
		return c.NoContent(http.StatusNoContent)
	})
	if !assert.NoError(t, h(c)) {
		return
	}
	assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
	assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
	assert.Equal(t, 0, len(rec.Body.Bytes()))
}
// TestGzipResponseWriter_CanUnwrap checks that Unwrap returns exactly the wrapped ResponseWriter.
func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
	inner := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
	grw := gzipResponseWriter{ResponseWriter: inner}

	assert.Equal(t, inner, grw.Unwrap())
}
// TestGzipResponseWriter_CanHijack checks that Hijack is delegated to a wrapped writer that supports
// it. The test double is expected to answer Hijack with the error "can hijack", so seeing that
// error proves the call reached it.
func TestGzipResponseWriter_CanHijack(t *testing.T) {
	inner := testResponseWriterUnwrapperHijack{
		testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()},
	}
	grw := gzipResponseWriter{ResponseWriter: &inner} // this RW supports hijacking through unwrapping

	_, _, err := grw.Hijack()
	assert.EqualError(t, err, "can hijack")
}
// TestGzipResponseWriter_CanNotHijack checks that Hijack fails with "feature not supported" when the
// wrapped ResponseWriter provides no way to hijack the connection.
func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
	trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
	grw := gzipResponseWriter{
		ResponseWriter: &trwu, // this RW does NOT support hijacking (plain unwrapper, no Hijack method)
	}
	_, _, err := grw.Hijack()
	assert.EqualError(t, err, "feature not supported")
}
// BenchmarkGzip measures the gzip middleware on a tiny response, using a fresh recorder and context
// for every iteration.
func BenchmarkGzip(b *testing.B) {
	e := echo.New()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)

	writeBody := func(c *echo.Context) error {
		c.Response().Write([]byte("test")) // For Content-Type sniffing
		return nil
	}
	h := Gzip()(writeBody)

	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		c := e.NewContext(req, httptest.NewRecorder())
		_ = h(c)
	}
}
================================================
FILE: middleware/context_timeout.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"context"
"errors"
"time"
"github.com/labstack/echo/v5"
)
// ContextTimeoutConfig defines the config for ContextTimeout middleware.
type ContextTimeoutConfig struct {
	// Skipper defines a function to skip middleware.
	//
	// Optional. Defaults to DefaultSkipper.
	Skipper Skipper

	// ErrorHandler is called with the non-nil error returned by the next handler in the chain. Its
	// result is what the middleware returns.
	//
	// Optional. The default converts errors matching context.DeadlineExceeded into
	// echo.ErrServiceUnavailable (wrapping the original error) and returns every other error unchanged.
	ErrorHandler func(c *echo.Context, err error) error

	// Timeout is the duration of the deadline placed on the request context seen by the next handlers.
	//
	// Required. A zero value is rejected as invalid configuration.
	Timeout time.Duration
}
// ContextTimeout returns a middleware that gives the request context a deadline of `timeout`. When
// the handler chain returns an error matching context.DeadlineExceeded, the middleware returns a 503
// Service Unavailable error instead; other errors are returned unchanged.
//
// It panics if timeout is zero. Use ContextTimeoutWithConfig for a custom Skipper or ErrorHandler.
func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc {
	cfg := ContextTimeoutConfig{Timeout: timeout}
	return ContextTimeoutWithConfig(cfg)
}
// ContextTimeoutWithConfig returns a ContextTimeout middleware built from config.
// It panics when config is invalid (for example when Timeout is zero).
func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts ContextTimeoutConfig to middleware or returns an error for invalid configuration.
//
// Timeout must be positive. Unset optional fields are defaulted:
//   - Skipper: DefaultSkipper.
//   - ErrorHandler: errors matching context.DeadlineExceeded become echo.ErrServiceUnavailable
//     (wrapping the original error); every other error is returned unchanged.
//
// For each request that is not skipped, the middleware replaces the request with a copy whose context
// carries the configured timeout. It then calls next and passes any non-nil error through ErrorHandler.
func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	switch {
	case config.Timeout == 0:
		return nil, errors.New("timeout must be set")
	case config.Timeout < 0:
		// context.WithTimeout with a negative duration yields an already-expired context, so every
		// request would start past its deadline. Treat this as a misconfiguration.
		return nil, errors.New("timeout must be positive")
	}
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}
	if config.ErrorHandler == nil {
		config.ErrorHandler = func(c *echo.Context, err error) error {
			// errors.Is(nil, ...) is false, so a nil err is simply returned as-is.
			if errors.Is(err, context.DeadlineExceeded) {
				return echo.ErrServiceUnavailable.Wrap(err)
			}
			return err
		}
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
			defer cancel()

			c.SetRequest(c.Request().WithContext(timeoutContext))
			if err := next(c); err != nil {
				return config.ErrorHandler(c, err)
			}
			return nil
		}
	}, nil
}
================================================
FILE: middleware/context_timeout_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"context"
"errors"
"github.com/labstack/echo/v5"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestContextTimeoutSkipper checks that a Skipper returning true bypasses the timeout entirely.
func TestContextTimeoutSkipper(t *testing.T) {
	t.Parallel()
	m := ContextTimeoutWithConfig(ContextTimeoutConfig{
		// Named c (not context) so the imported context package is not shadowed.
		Skipper: func(c *echo.Context) bool {
			return true
		},
		Timeout: 10 * time.Millisecond,
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	e := echo.New()
	c := e.NewContext(req, rec)

	err := m(func(c *echo.Context) error {
		if err := sleepWithContext(c.Request().Context(), 20*time.Millisecond); err != nil {
			return err
		}
		return errors.New("response from handler")
	})(c)

	// Because the middleware is skipped, no deadline is installed. The handler outlives the 10ms
	// timeout and its own error comes back. Without the skipper the sleep would be cut short and a
	// timeout error would be returned instead.
	assert.EqualError(t, err, "response from handler")
}
func TestContextTimeoutWithTimeout0(t *testing.T) {
t.Parallel()
assert.Panics(t, func() {
ContextTimeout(time.Duration(0))
})
}
// TestContextTimeoutErrorOutInHandler checks that an error returned by the handler is passed back
// up the chain instead of being written to the response.
func TestContextTimeoutErrorOutInHandler(t *testing.T) {
	t.Parallel()

	// A non-zero Timeout is mandatory; constructing the middleware panics without it.
	m := ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: 10 * time.Millisecond})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	e := echo.New()
	c := e.NewContext(req, rec)

	// Sentinel status code: if anything wrote a status (even 200), Code would no longer be 1.
	rec.Code = 1

	handler := m(func(c *echo.Context) error {
		// The error must not be written to the response. Middlewares upstream of the timeout
		// middleware can only handle it while the response is not yet committed (no status code
		// written).
		return echo.NewHTTPError(http.StatusTeapot, "err")
	})
	err := handler(c)

	assert.Error(t, err)
	assert.EqualError(t, err, "code=418, message=err")
	assert.Equal(t, 1, rec.Code)
	assert.Equal(t, "", rec.Body.String())
}
// TestContextTimeoutSuccessfulRequest checks that a handler finishing within the timeout writes its
// response normally.
func TestContextTimeoutSuccessfulRequest(t *testing.T) {
	t.Parallel()

	// A non-zero Timeout is mandatory; constructing the middleware panics without it.
	m := ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: 10 * time.Millisecond})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	e := echo.New()
	c := e.NewContext(req, rec)

	createdJSON := func(c *echo.Context) error {
		return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
	}

	assert.NoError(t, m(createdJSON)(c))
	assert.Equal(t, http.StatusCreated, rec.Code)
	assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
}
// TestContextTimeoutTestRequestClone checks that the request passed to the next handler still
// exposes the original cookies, form body and query string. The middleware hands on a copy created
// with WithContext.
func TestContextTimeoutTestRequestClone(t *testing.T) {
	t.Parallel()
	req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
	req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rec := httptest.NewRecorder()

	m := ContextTimeoutWithConfig(ContextTimeoutConfig{
		// A non-zero Timeout is mandatory; constructing the middleware panics without it.
		Timeout: 1 * time.Second,
	})

	e := echo.New()
	c := e.NewContext(req, rec)

	err := m(func(c *echo.Context) error {
		// Cookie test
		cookie, err := c.Request().Cookie("cookie")
		if assert.NoError(t, err) {
			assert.Equal(t, "cookie", cookie.Name)
			assert.Equal(t, "value", cookie.Value)
		}

		// Form values
		if assert.NoError(t, c.Request().ParseForm()) {
			assert.Equal(t, "value", c.Request().FormValue("form"))
		}

		// Query string. Get returns "" for a missing key instead of panicking on a nil slice index.
		assert.Equal(t, "value", c.Request().URL.Query().Get("query"))
		return nil
	})(c)

	assert.NoError(t, err)
}
// TestContextTimeoutWithDefaultErrorMessage checks that, without a custom ErrorHandler, a handler
// that runs past the deadline ends in a 503 *echo.HTTPError with the default message.
func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
	t.Parallel()

	const timeout = 10 * time.Millisecond
	m := ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	e := echo.New()
	c := e.NewContext(req, rec)

	slowHandler := func(c *echo.Context) error {
		// Sleeps well past the timeout, so the context is expected to end first.
		if err := sleepWithContext(c.Request().Context(), 80*time.Millisecond); err != nil {
			return err
		}
		return c.String(http.StatusOK, "Hello, World!")
	}
	err := m(slowHandler)(c)

	assert.Error(t, err)
	if assert.IsType(t, &echo.HTTPError{}, err) {
		httpErr := err.(*echo.HTTPError)
		assert.Equal(t, http.StatusServiceUnavailable, httpErr.Code)
		assert.Equal(t, "Service Unavailable", httpErr.Message)
	}
}
// TestContextTimeoutCanHandleContextDeadlineOnNextHandler checks that a custom ErrorHandler receives
// the deadline error returned by the next handler, and that its result is what the middleware returns.
func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
	t.Parallel()

	timeoutErrorHandler := func(c *echo.Context, err error) error {
		if err != nil {
			if errors.Is(err, context.DeadlineExceeded) {
				return &echo.HTTPError{
					Code:    http.StatusServiceUnavailable,
					Message: "Timeout! change me",
				}
			}
			return err
		}
		return nil
	}

	timeout := 50 * time.Millisecond
	m := ContextTimeoutWithConfig(ContextTimeoutConfig{
		Timeout:      timeout,
		ErrorHandler: timeoutErrorHandler,
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	e := echo.New()
	c := e.NewContext(req, rec)

	err := m(func(c *echo.Context) error {
		// The request context must carry the deadline set by the ContextTimeout middleware. Check
		// this before sleeping: the sleep is meant to outlast the timeout and return early, so a
		// check placed after it would never run.
		if _, ok := c.Request().Context().Deadline(); !ok {
			assert.Fail(t, "No timeout set on Request Context")
		}

		// NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order
		// for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky.
		if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil {
			return err
		}
		return c.String(http.StatusOK, "Hello, World!")
	})(c)

	assert.Error(t, err)
	// Guard the type assertion so a wrong error type fails the test instead of panicking.
	if assert.IsType(t, &echo.HTTPError{}, err) {
		httpErr := err.(*echo.HTTPError)
		assert.Equal(t, http.StatusServiceUnavailable, httpErr.Code)
		assert.Equal(t, "Timeout! change me", httpErr.Message)
	}
}
// sleepWithContext blocks for d or until ctx is done, whichever happens first.
//
// It returns nil when the full duration elapsed. Otherwise it returns ctx.Err(): context.DeadlineExceeded
// for an expired deadline, context.Canceled for an explicit cancellation. Previously every early wake-up
// was reported as DeadlineExceeded, which misreported cancellations.
func sleepWithContext(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
================================================
FILE: middleware/cors.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/labstack/echo/v5"
)
// CORSConfig defines the config for CORS middleware.
type CORSConfig struct {
	// Skipper defines a function to skip middleware.
	Skipper Skipper

	// AllowOrigins determines the value of the Access-Control-Allow-Origin
	// response header. This header defines a list of origins that may access the
	// resource.
	//
	// Origin consists of the following parts: `scheme + "://" + host + optional ":" + port`
	// Wildcard can be used, but has to be set explicitly []string{"*"}
	// Example: `https://example.com`, `http://example.com:8080`, `*`
	//
	// Entries are compared with the request `Origin` header case-insensitively. On a match, the
	// configured entry is sent back as Access-Control-Allow-Origin.
	//
	// Security: use extreme caution when handling the origin, and carefully
	// validate any logic. Remember that attackers may register hostile domain names.
	// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
	//
	// Mandatory, unless UnsafeAllowOriginFunc is set.
	AllowOrigins []string

	// UnsafeAllowOriginFunc is an optional custom function to validate the origin. It takes the
	// origin as an argument and returns
	//   - string, allowed origin
	//   - bool, true if allowed or false otherwise.
	//   - error, if an error is returned, it is returned immediately by the handler.
	// If this option is set, AllowOrigins is ignored.
	//
	// Security: use extreme caution when handling the origin, and carefully
	// validate any logic. Remember that attackers may register hostile (sub)domain names.
	// See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
	//
	// Sub-domain checks example:
	//   UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
	//     if strings.HasSuffix(origin, ".example.com") {
	//       return origin, true, nil
	//     }
	//     return "", false, nil
	//   },
	//
	// Optional.
	UnsafeAllowOriginFunc func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error)

	// AllowMethods determines the value of the Access-Control-Allow-Methods
	// response header. This header specifies the list of methods allowed when
	// accessing the resource. This is used in response to a preflight request.
	//
	// Optional. Defaults to GET, HEAD, PUT, PATCH, POST, DELETE.
	// If AllowMethods is left empty, preflight responses fill the `Access-Control-Allow-Methods`
	// header from the `Allow` value that echo.Router stored in the context
	// (echo.ContextKeyHeaderAllow). They fall back to the default list when no such value exists.
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
	AllowMethods []string

	// AllowHeaders determines the value of the Access-Control-Allow-Headers
	// response header. This header is used in response to a preflight request to
	// indicate which HTTP headers can be used when making the actual request.
	//
	// Optional. When empty, preflight responses echo back the request's
	// `Access-Control-Request-Headers` value (if the request sent one).
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
	AllowHeaders []string

	// AllowCredentials determines the value of the
	// Access-Control-Allow-Credentials response header. This header indicates
	// whether or not the response to the request can be exposed when the
	// credentials mode (Request.credentials) is true. When used as part of a
	// response to a preflight request, this indicates whether or not the actual
	// request can be made using credentials.
	//
	// Optional. Default value false, in which case the header is not set.
	//
	// Security: `AllowCredentials = true` together with `AllowOrigins = *` is rejected as invalid
	// configuration; use UnsafeAllowOriginFunc if you really need that behavior.
	// See "Exploiting CORS misconfigurations for Bitcoins and bounties",
	// https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
	AllowCredentials bool

	// ExposeHeaders determines the value of Access-Control-Expose-Headers, which
	// defines a list of headers that clients are allowed to access. It is sent only on
	// non-preflight responses to allowed origins.
	//
	// Optional. Default value []string{}, in which case the header is not set.
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
	ExposeHeaders []string

	// MaxAge determines the value of the Access-Control-Max-Age response header.
	// This header indicates how long (in seconds) the results of a preflight
	// request can be cached.
	// The header is set only if MaxAge != 0. A negative value sends "0", which instructs browsers
	// not to cache that response.
	//
	// Optional. Default value 0 - meaning header is not sent.
	//
	// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
	MaxAge int
}
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware that allows the given origins and
// uses defaults for every other CORSConfig field. It panics on invalid configuration, e.g. when no
// origin is given.
// See also [MDN: Cross-Origin Resource Sharing (CORS)].
//
// Origin consists of the following parts: `scheme + "://" + host + optional ":" + port`
// Wildcard `*` can be used, but has to be set explicitly.
// Example: `https://example.com`, `http://example.com:8080`, `*`
//
// Security: Poorly configured CORS can compromise security because it allows
// relaxation of the browser's Same-Origin policy. See [Exploiting CORS
// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin
// resource sharing (CORS)] for more details.
//
// [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
// [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
// [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors
func CORS(allowOrigins ...string) echo.MiddlewareFunc {
	return CORSWithConfig(CORSConfig{AllowOrigins: allowOrigins})
}
// CORSWithConfig builds a CORS middleware from config and panics if the configuration is invalid.
// See: [CORS].
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration.
//
// The configuration is invalid when:
//   - neither AllowOrigins nor UnsafeAllowOriginFunc is set;
//   - AllowOrigins contains "*" while AllowCredentials is true;
//   - an AllowOrigins entry fails origin validation.
//
// For requests that are not skipped:
//   - No Origin header: OPTIONS requests get 204 No Content; all others continue to next.
//   - Origin not allowed: OPTIONS requests get 204 without CORS headers; all others continue to next.
//   - Origin allowed: Access-Control-Allow-Origin (and -Allow-Credentials when configured) are set.
//     Non-OPTIONS requests also get Access-Control-Expose-Headers and continue to next. OPTIONS
//     (preflight) requests get the preflight headers and end with 204 No Content.
func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	// Defaults
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}
	hasCustomAllowMethods := true
	if len(config.AllowMethods) == 0 {
		hasCustomAllowMethods = false
		config.AllowMethods = []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}
	}

	// These header values never change for the lifetime of the middleware, so build them once.
	allowMethods := strings.Join(config.AllowMethods, ",")
	allowHeaders := strings.Join(config.AllowHeaders, ",")
	exposeHeaders := strings.Join(config.ExposeHeaders, ",")

	maxAge := "0" // used for negative MaxAge: tells browsers not to cache the preflight response
	if config.MaxAge > 0 {
		maxAge = strconv.Itoa(config.MaxAge)
	}

	allowOriginFunc := config.UnsafeAllowOriginFunc
	if config.UnsafeAllowOriginFunc == nil {
		if len(config.AllowOrigins) == 0 {
			return nil, errors.New("at least one AllowOrigins is required or UnsafeAllowOriginFunc must be provided")
		}
		// Take a private copy of the origins BEFORE creating the method value below. A method value with a
		// value receiver copies the receiver when it is evaluated. Copying afterwards would leave
		// defaultAllowOriginFunc reading the caller's backing array, which the caller could still mutate.
		config.AllowOrigins = append([]string(nil), config.AllowOrigins...)
		allowOriginFunc = config.defaultAllowOriginFunc
		for _, origin := range config.AllowOrigins {
			if origin == "*" {
				if config.AllowCredentials {
					return nil, fmt.Errorf("* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc")
				}
				allowOriginFunc = config.starAllowOriginFunc
				break
			}
			if err := validateOrigin(origin, "allow origin"); err != nil {
				return nil, err
			}
		}
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			req := c.Request()
			res := c.Response()
			origin := req.Header.Get(echo.HeaderOrigin)

			res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)

			// A preflight request is an OPTIONS request that uses three HTTP request headers:
			// Access-Control-Request-Method, Access-Control-Request-Headers, and Origin.
			// See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
			// For simplicity, only the method is checked here; the `Origin` header is checked later.
			preflight := req.Method == http.MethodOptions

			// The router adds a special handler for OPTIONS, but this middleware does not call next
			// for OPTIONS requests. CORS requests carry no cookies or authentication headers by
			// default, so calling next(c) could get stuck in auth middlewares.
			// The `Allow` header is still sent for non-CORS OPTIONS requests, as the router's
			// default handler does.
			routerAllowMethods := ""
			if preflight {
				tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string)
				if ok && tmpAllowMethods != "" {
					routerAllowMethods = tmpAllowMethods
					c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods)
				}
			}

			// No Origin provided. This is (probably) not a request from an actual browser, so
			// continue executing the middleware chain.
			if origin == "" {
				if preflight { // req.Method=OPTIONS
					return c.NoContent(http.StatusNoContent)
				}
				return next(c) // let non-browser calls through
			}

			allowedOrigin, allowed, err := allowOriginFunc(c, origin)
			if err != nil {
				return err
			}
			if !allowed {
				// Origin existed and was NOT allowed
				if preflight {
					// From: https://github.com/labstack/echo/issues/2767
					// If the request's origin isn't allowed by the CORS configuration,
					// the middleware should simply omit the relevant CORS headers from the response
					// and let the browser fail the CORS check (if any).
					return c.NoContent(http.StatusNoContent)
				}
				// From: https://github.com/labstack/echo/issues/2767
				// no CORS middleware should block non-preflight requests;
				// such requests should be let through. One reason is that not all requests that
				// carry an Origin header participate in the CORS protocol.
				return next(c)
			}

			// Origin existed and was allowed

			res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
			if config.AllowCredentials {
				res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
			}

			// Simple requests are let through
			if !preflight {
				if exposeHeaders != "" {
					res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
				}
				return next(c)
			}

			// The code below handles preflight (OPTIONS) requests.
			//
			// A preflight always ends with c.NoContent(http.StatusNoContent). We do not know whether
			// the end of the handler chain is an actual OPTIONS route or a 404/405 route, and the
			// latter's status code would confuse browsers.
			res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
			res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)

			if !hasCustomAllowMethods && routerAllowMethods != "" {
				res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
			} else {
				res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
			}

			if allowHeaders != "" {
				res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
			} else {
				h := req.Header.Get(echo.HeaderAccessControlRequestHeaders)
				if h != "" {
					res.Header().Set(echo.HeaderAccessControlAllowHeaders, h)
				}
			}
			if config.MaxAge != 0 {
				res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
			}

			return c.NoContent(http.StatusNoContent)
		}
	}, nil
}
// starAllowOriginFunc allows every origin and always answers with the wildcard "*", whatever the
// request origin is.
func (config CORSConfig) starAllowOriginFunc(_ *echo.Context, _ string) (string, bool, error) {
	const wildcardOrigin = "*"
	return wildcardOrigin, true, nil
}
// defaultAllowOriginFunc reports whether origin matches an entry of config.AllowOrigins, compared
// case-insensitively. On a match it returns the configured entry (with its configured casing);
// otherwise it returns "", false. It never returns an error.
func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) {
	for i := range config.AllowOrigins {
		if candidate := config.AllowOrigins[i]; strings.EqualFold(candidate, origin) {
			return candidate, true, nil
		}
	}
	return "", false, nil
}
================================================
FILE: middleware/cors_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"cmp"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestCORS checks that a preflight request from a concrete origin against CORS("*") is answered
// with 204 and a wildcard Access-Control-Allow-Origin.
func TestCORS(t *testing.T) {
	e := echo.New()

	req := httptest.NewRequest(http.MethodOptions, "/", nil) // Preflight request
	req.Header.Set(echo.HeaderOrigin, "http://example.com")
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	noop := func(c *echo.Context) error { return nil }
	handler := CORS("*")(noop)

	assert.NoError(t, handler(c))
	assert.Equal(t, http.StatusNoContent, rec.Code)
	assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
}
// TestCORSConfig is a table test over CORSConfig. Each case builds the middleware from givenConfig
// and either expects the configuration error expectErr, or runs a request and checks the response
// headers against expectHeaders and notExpectHeaders. In notExpectHeaders, an empty value means the
// header must be absent.
func TestCORSConfig(t *testing.T) {
	var testCases = []struct {
		name             string
		givenConfig      *CORSConfig
		whenMethod       string
		whenHeaders      map[string]string
		expectHeaders    map[string]string
		notExpectHeaders map[string]string
		expectErr        string
	}{
		{
			name: "ok, wildcard origin",
			givenConfig: &CORSConfig{
				AllowOrigins: []string{"*"},
			},
			whenHeaders:   map[string]string{echo.HeaderOrigin: "localhost"},
			expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"},
		},
		{
			name: "ok, wildcard AllowedOrigin with no Origin header in request",
			givenConfig: &CORSConfig{
				AllowOrigins: []string{"*"},
			},
			notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""},
		},
		{
			name: "ok, specific AllowOrigins and AllowCredentials",
			givenConfig: &CORSConfig{
				AllowOrigins:     []string{"http://localhost", "http://localhost:8080"},
				AllowCredentials: true,
				MaxAge:           3600,
			},
			whenHeaders: map[string]string{echo.HeaderOrigin: "http://localhost"},
			expectHeaders: map[string]string{
				echo.HeaderAccessControlAllowOrigin:      "http://localhost",
				echo.HeaderAccessControlAllowCredentials: "true",
			},
		},
		{
			name: "ok, preflight request with matching origin for `AllowOrigins`",
			givenConfig: &CORSConfig{
				AllowOrigins:     []string{"http://localhost"},
				AllowCredentials: true,
				MaxAge:           3600,
			},
			whenMethod: http.MethodOptions,
			whenHeaders: map[string]string{
				echo.HeaderOrigin:      "http://localhost",
				echo.HeaderContentType: echo.MIMEApplicationJSON,
			},
			expectHeaders: map[string]string{
				echo.HeaderAccessControlAllowOrigin:      "http://localhost",
				echo.HeaderAccessControlAllowMethods:     "GET,HEAD,PUT,PATCH,POST,DELETE",
				echo.HeaderAccessControlAllowCredentials: "true",
				echo.HeaderAccessControlMaxAge:           "3600",
			},
		},
		{
			name: "ok, preflight request when `Access-Control-Max-Age` is set",
			givenConfig: &CORSConfig{
				AllowOrigins:     []string{"http://localhost"},
				AllowCredentials: true,
				MaxAge:           1,
			},
			whenMethod: http.MethodOptions,
			whenHeaders: map[string]string{
				echo.HeaderOrigin:      "http://localhost",
				echo.HeaderContentType: echo.MIMEApplicationJSON,
			},
			expectHeaders: map[string]string{
				echo.HeaderAccessControlMaxAge: "1",
			},
		},
		{
			name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response",
			givenConfig: &CORSConfig{
				AllowOrigins:     []string{"http://localhost"},
				AllowCredentials: true,
				MaxAge:           -1, // forces `Access-Control-Max-Age: 0`
			},
			whenMethod: http.MethodOptions,
			whenHeaders: map[string]string{
				echo.HeaderOrigin:      "http://localhost",
				echo.HeaderContentType: echo.MIMEApplicationJSON,
			},
			expectHeaders: map[string]string{
				echo.HeaderAccessControlMaxAge: "0",
			},
		},
		{
			name: "ok, CORS check are skipped",
			givenConfig: &CORSConfig{
				AllowOrigins:     []string{"http://localhost"},
				AllowCredentials: true,
				Skipper: func(c *echo.Context) bool {
					return true
				},
			},
			whenMethod: http.MethodOptions,
			whenHeaders: map[string]string{
				echo.HeaderOrigin:      "http://localhost",
				echo.HeaderContentType: echo.MIMEApplicationJSON,
			},
			// Empty values assert that the headers are absent. Concrete values would be vacuous:
			// a non-skipped middleware would send "http://localhost", which never equals "localhost".
			notExpectHeaders: map[string]string{
				echo.HeaderAccessControlAllowOrigin:      "",
				echo.HeaderAccessControlAllowMethods:     "",
				echo.HeaderAccessControlAllowCredentials: "",
				echo.HeaderAccessControlMaxAge:           "",
			},
		},
		{
			name: "nok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
			givenConfig: &CORSConfig{
				AllowOrigins:     []string{"*"},
				AllowCredentials: true,
				MaxAge:           3600,
			},
			whenMethod: http.MethodOptions,
			whenHeaders: map[string]string{
				echo.HeaderOrigin:      "localhost",
				echo.HeaderContentType: echo.MIMEApplicationJSON,
			},
			expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`,
		},
		{
			name: "nok, preflight request with invalid `AllowOrigins` value",
			givenConfig: &CORSConfig{
				AllowOrigins: []string{"http://server", "missing-scheme"},
			},
			expectErr: `allow origin is missing scheme or host: missing-scheme`,
		},
		{
			name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false",
			givenConfig: &CORSConfig{
				AllowOrigins:     []string{"*"},
				AllowCredentials: false, // important for this testcase
				MaxAge:           3600,
			},
			whenMethod: http.MethodOptions,
			whenHeaders: map[string]string{
				echo.HeaderOrigin:      "localhost",
				echo.HeaderContentType: echo.MIMEApplicationJSON,
			},
			expectHeaders: map[string]string{
				echo.HeaderAccessControlAllowOrigin:  "*",
				echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
				echo.HeaderAccessControlMaxAge:       "3600",
			},
			notExpectHeaders: map[string]string{
				echo.HeaderAccessControlAllowCredentials: "",
			},
		},
		{
			name: "nok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
			givenConfig: &CORSConfig{
				AllowOrigins:     []string{"*"},
				AllowCredentials: true,
				MaxAge:           3600,
			},
			whenMethod: http.MethodOptions,
			whenHeaders: map[string]string{
				echo.HeaderOrigin:      "localhost",
				echo.HeaderContentType: echo.MIMEApplicationJSON,
			},
			expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`,
		},
		{
			name: "ok, preflight request with Access-Control-Request-Headers",
			givenConfig: &CORSConfig{
				AllowOrigins: []string{"*"},
			},
			whenMethod: http.MethodOptions,
			whenHeaders: map[string]string{
				echo.HeaderOrigin:                      "localhost",
				echo.HeaderContentType:                 echo.MIMEApplicationJSON,
				echo.HeaderAccessControlRequestHeaders: "Special-Request-Header",
			},
			expectHeaders: map[string]string{
				echo.HeaderAccessControlAllowOrigin:  "*",
				echo.HeaderAccessControlAllowHeaders: "Special-Request-Header",
				echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
			},
		},
		{
			name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *",
			givenConfig: &CORSConfig{
				UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) {
					if strings.HasSuffix(origin, ".example.com") {
						allowed = true
					}
					return origin, allowed, nil
				},
			},
			whenMethod:    http.MethodOptions,
			whenHeaders:   map[string]string{echo.HeaderOrigin: "http://aaa.example.com"},
			expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"},
		},
		{
			name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *",
			givenConfig: &CORSConfig{
				UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
					if strings.HasSuffix(origin, ".example.com") {
						return origin, true, nil
					}
					return "", false, nil
				},
			},
			whenMethod:    http.MethodOptions,
			whenHeaders:   map[string]string{echo.HeaderOrigin: "http://bbb.example.com"},
			expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()

			var mw echo.MiddlewareFunc
			var err error
			if tc.givenConfig != nil {
				mw, err = tc.givenConfig.ToMiddleware()
			} else {
				mw, err = CORSConfig{}.ToMiddleware()
			}
			// Check expectErr before err, so a case expecting an error fails when ToMiddleware
			// unexpectedly succeeds, rather than falling through and passing.
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
				return
			}
			if err != nil {
				t.Fatal(err)
			}

			h := mw(func(c *echo.Context) error {
				return nil
			})

			method := cmp.Or(tc.whenMethod, http.MethodGet)
			req := httptest.NewRequest(method, "/", nil)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)
			for k, v := range tc.whenHeaders {
				req.Header.Set(k, v)
			}

			err = h(c)
			assert.NoError(t, err)

			header := rec.Header()
			for k, v := range tc.expectHeaders {
				assert.Equal(t, v, header.Get(k), "header: `%v` should be `%v`", k, v)
			}
			for k, v := range tc.notExpectHeaders {
				if v == "" {
					assert.Len(t, header.Values(k), 0, "header: `%v` should not be set", k)
				} else {
					assert.NotEqual(t, v, header.Get(k), "header: `%v` should not be `%v`", k, v)
				}
			}
		})
	}
}
// Test_allowOriginScheme checks that origin matching includes the scheme: an http origin must not
// match an https entry and vice versa.
func Test_allowOriginScheme(t *testing.T) {
	tests := []struct {
		domain, pattern string
		expected        bool
	}{
		{
			domain:   "http://example.com",
			pattern:  "http://example.com",
			expected: true,
		},
		{
			domain:   "https://example.com",
			pattern:  "https://example.com",
			expected: true,
		},
		{
			domain:   "http://example.com",
			pattern:  "https://example.com",
			expected: false,
		},
		{
			domain:   "https://example.com",
			pattern:  "http://example.com",
			expected: false,
		},
	}

	e := echo.New()
	for _, tt := range tests {
		t.Run(tt.domain+" vs "+tt.pattern, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodOptions, "/", nil)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)
			req.Header.Set(echo.HeaderOrigin, tt.domain)

			cors := CORSWithConfig(CORSConfig{
				AllowOrigins: []string{tt.pattern},
			})
			h := cors(func(c *echo.Context) error { return echo.ErrNotFound })

			// Preflight (OPTIONS) requests are answered by the middleware with 204 and never reach
			// next, so no error is expected for either an allowed or a rejected origin.
			assert.NoError(t, h(c))
			if tt.expected {
				assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
			} else {
				assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
			}
		})
	}
}
// TestCORSWithConfig_AllowMethods checks how preflight responses fill `Allow` (from the router value
// stored under echo.ContextKeyHeaderAllow) and `Access-Control-Allow-Methods`, with custom and with
// default AllowMethods.
func TestCORSWithConfig_AllowMethods(t *testing.T) {
	var testCases = []struct {
		name              string
		givenAllowOrigins []string
		givenAllowMethods []string

		whenAllowContextKey string
		whenOrigin          string

		expectAllow                     string
		expectAccessControlAllowMethods string
	}{
		{
			name:                "custom AllowMethods, preflight, no origin, sets only allow header from context key",
			givenAllowOrigins:   []string{"*"},
			givenAllowMethods:   []string{http.MethodGet, http.MethodHead},
			whenAllowContextKey: "OPTIONS, GET",
			whenOrigin:          "",
			expectAllow:         "OPTIONS, GET",
		},
		{
			name:                "default AllowMethods, preflight, no origin, no allow header in context key and in response",
			givenAllowOrigins:   []string{"*"},
			givenAllowMethods:   nil,
			whenAllowContextKey: "",
			whenOrigin:          "",
			expectAllow:         "",
		},
		{
			name:                            "custom AllowMethods, preflight, existing origin, sets both headers different values",
			givenAllowOrigins:               []string{"*"},
			givenAllowMethods:               []string{http.MethodGet, http.MethodHead},
			whenAllowContextKey:             "OPTIONS, GET",
			whenOrigin:                      "http://google.com",
			expectAllow:                     "OPTIONS, GET",
			expectAccessControlAllowMethods: "GET,HEAD",
		},
		{
			name:                            "default AllowMethods, preflight, existing origin, sets both headers",
			givenAllowOrigins:               []string{"*"},
			givenAllowMethods:               nil,
			whenAllowContextKey:             "OPTIONS, GET",
			whenOrigin:                      "http://google.com",
			expectAllow:                     "OPTIONS, GET",
			expectAccessControlAllowMethods: "OPTIONS, GET",
		},
		{
			name:                            "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods",
			givenAllowOrigins:               []string{"*"},
			givenAllowMethods:               nil,
			whenAllowContextKey:             "",
			whenOrigin:                      "http://google.com",
			expectAllow:                     "",
			expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			e.GET("/test", func(c *echo.Context) error {
				return c.String(http.StatusOK, "OK")
			})

			cors := CORSWithConfig(CORSConfig{
				AllowOrigins: tc.givenAllowOrigins,
				AllowMethods: tc.givenAllowMethods,
			})

			req := httptest.NewRequest(http.MethodOptions, "/test", nil)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
			if tc.whenAllowContextKey != "" {
				c.Set(echo.ContextKeyHeaderAllow, tc.whenAllowContextKey)
			}

			h := cors(func(c *echo.Context) error {
				return c.String(http.StatusOK, "OK")
			})
			// Every case is an OPTIONS request answered by the middleware itself with 204, so no error.
			assert.NoError(t, h(c))

			assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
			assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
		})
	}
}
// TestCorsHeaders runs requests through the whole Echo handler chain (router + CORS middleware) and checks
// which CORS related response headers are produced for preflight (OPTIONS) and non-preflight (GET) requests
// for different combinations of the request Origin header and the configured AllowOrigins.
//
// `expected` marks cases where the CORS middleware is expected to accept the origin and emit
// Access-Control-Allow-* headers.
func TestCorsHeaders(t *testing.T) {
	tests := []struct {
		name              string
		originDomain      string
		method            string
		allowedOrigin     string
		expected          bool
		expectStatus      int
		expectAllowHeader string
	}{
		{
			name:          "non-preflight request, allow any origin, missing origin header = no CORS logic done",
			originDomain:  "",
			allowedOrigin: "*",
			method:        http.MethodGet,
			expected:      false,
			expectStatus:  http.StatusOK,
		},
		{
			name:          "non-preflight request, allow any origin, specific origin domain",
			originDomain:  "http://example.com",
			allowedOrigin: "*",
			method:        http.MethodGet,
			expected:      true,
			expectStatus:  http.StatusOK,
		},
		{
			name:          "non-preflight request, allow specific origin, missing origin header = no CORS logic done",
			originDomain:  "", // Request does not have Origin header
			allowedOrigin: "http://example.com",
			method:        http.MethodGet,
			expected:      false,
			expectStatus:  http.StatusOK,
		},
		{
			name:          "non-preflight request, allow specific origin, different origin header = CORS logic failure",
			originDomain:  "http://bar.com",
			allowedOrigin: "http://example.com",
			method:        http.MethodGet,
			expected:      false,
			expectStatus:  http.StatusOK,
		},
		{
			name:          "non-preflight request, allow specific origin, matching origin header = CORS logic done",
			originDomain:  "http://example.com",
			allowedOrigin: "http://example.com",
			method:        http.MethodGet,
			expected:      true,
			expectStatus:  http.StatusOK,
		},
		{
			name:              "preflight, allow any origin, missing origin header = no CORS logic done",
			originDomain:      "", // Request does not have Origin header
			allowedOrigin:     "*",
			method:            http.MethodOptions,
			expected:          false,
			expectStatus:      http.StatusNoContent,
			expectAllowHeader: "OPTIONS, GET, POST",
		},
		{
			name:              "preflight, allow any origin, existing origin header = CORS logic done",
			originDomain:      "http://example.com",
			allowedOrigin:     "*",
			method:            http.MethodOptions,
			expected:          true,
			expectStatus:      http.StatusNoContent,
			expectAllowHeader: "OPTIONS, GET, POST",
		},
		{
			// NB: this case configures a specific origin; its name previously duplicated the "allow any origin"
			// case above, which made t.Run silently rename it with a "#01" suffix.
			name:              "preflight, allow specific origin, missing origin header = no CORS logic done",
			originDomain:      "", // Request does not have Origin header
			allowedOrigin:     "http://example.com",
			method:            http.MethodOptions,
			expected:          false,
			expectStatus:      http.StatusNoContent,
			expectAllowHeader: "OPTIONS, GET, POST",
		},
		{
			name:              "preflight, allow specific origin, different origin header = no CORS logic done",
			originDomain:      "http://bar.com",
			allowedOrigin:     "http://example.com",
			method:            http.MethodOptions,
			expected:          false,
			expectStatus:      http.StatusNoContent,
			expectAllowHeader: "OPTIONS, GET, POST",
		},
		{
			name:              "preflight, allow specific origin, matching origin header = CORS logic done",
			originDomain:      "http://example.com",
			allowedOrigin:     "http://example.com",
			method:            http.MethodOptions,
			expected:          true,
			expectStatus:      http.StatusNoContent,
			expectAllowHeader: "OPTIONS, GET, POST",
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			e.Use(CORSWithConfig(CORSConfig{
				AllowOrigins: []string{tc.allowedOrigin},
			}))
			e.GET("/", func(c *echo.Context) error {
				return c.String(http.StatusOK, "OK")
			})
			e.POST("/", func(c *echo.Context) error {
				return c.String(http.StatusCreated, "OK")
			})

			req := httptest.NewRequest(tc.method, "/", nil)
			rec := httptest.NewRecorder()
			if tc.originDomain != "" {
				req.Header.Set(echo.HeaderOrigin, tc.originDomain)
			}

			// we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler
			e.ServeHTTP(rec, req)

			assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
			assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow))
			assert.Equal(t, tc.expectStatus, rec.Code)

			// A wildcard config answers with "*"; a specific config echoes the request origin back.
			expectedAllowOrigin := tc.originDomain
			if tc.allowedOrigin == "*" {
				expectedAllowOrigin = "*"
			}

			switch {
			case tc.expected && tc.method == http.MethodOptions:
				assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
				assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
				assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
			case tc.expected && tc.method == http.MethodGet:
				assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
				assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
			default:
				assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
				assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
			}
		})
	}
}
// Test_allowOriginFunc checks how the CORS middleware reflects the outcome of UnsafeAllowOriginFunc:
// an allowed origin is written to Access-Control-Allow-Origin, a rejected one leaves the header empty,
// and an error from the function is returned unchanged with no header written.
func Test_allowOriginFunc(t *testing.T) {
	const origin = "http://example.com"

	allowOriginFuncs := []func(c *echo.Context, origin string) (string, bool, error){
		// always allow
		func(c *echo.Context, o string) (string, bool, error) { return o, true, nil },
		// always reject
		func(c *echo.Context, o string) (string, bool, error) { return o, false, nil },
		// fail with an error
		func(c *echo.Context, o string) (string, bool, error) {
			return o, true, errors.New("this is a test error")
		},
	}

	e := echo.New()
	for _, fn := range allowOriginFuncs {
		req := httptest.NewRequest(http.MethodOptions, "/", nil)
		req.Header.Set(echo.HeaderOrigin, origin)
		rec := httptest.NewRecorder()
		c := e.NewContext(req, rec)

		mw, err := CORSConfig{UnsafeAllowOriginFunc: fn}.ToMiddleware()
		assert.NoError(t, err)

		err = mw(func(c *echo.Context) error { return echo.ErrNotFound })(c)

		wantOrigin, wantAllowed, wantErr := fn(c, origin)
		gotOrigin := rec.Header().Get(echo.HeaderAccessControlAllowOrigin)
		switch {
		case wantErr != nil:
			assert.Equal(t, wantErr, err)
			assert.Equal(t, "", gotOrigin)
		case wantAllowed:
			assert.Equal(t, wantOrigin, gotOrigin)
		default:
			assert.Equal(t, "", gotOrigin)
		}
	}
}
================================================
FILE: middleware/csrf.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"crypto/subtle"
"net/http"
"slices"
"strings"
"time"
"github.com/labstack/echo/v5"
)
// CSRFUsingSecFetchSite is the placeholder token value that the CSRF middleware stores in the context (under
// CSRFConfig.ContextKey) when the client sent a `Sec-Fetch-Site` header and the request is deemed safe by that
// check. It is a dummy token value that handlers can use to render the CSRF token field of a form.
//
// We know that the client is using a browser that supports the Sec-Fetch-Site header, so when the form is
// submitted in the future with this dummy token value it is OK: such a submission is expected to be accepted by
// the Sec-Fetch-Site check rather than by comparing this value. Although the request is safe, the template
// rendered by the handler may still need a value for the CSRF token field.
const CSRFUsingSecFetchSite = "_echo_csrf_using_sec_fetch_site_"
// CSRFConfig defines the config for CSRF middleware.
type CSRFConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. Default value DefaultSkipper.
	Skipper Skipper

	// TrustedOrigins permits any request with a `Sec-Fetch-Site` header whose `Origin` header
	// matches one of the specified values (the comparison is case-insensitive).
	// Values should be formatted as Origin header "scheme://host[:port]"; ToMiddleware returns an
	// error for values it considers invalid.
	//
	// See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
	// See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
	TrustedOrigins []string

	// AllowSecFetchSiteFunc allows custom behaviour for `Sec-Fetch-Site` requests that are about to
	// fail with a CSRF error, to be allowed or replaced with a custom error.
	// It is called only for requests with a method other than GET, HEAD, OPTIONS or TRACE, whose Origin is not
	// in TrustedOrigins and whose `Sec-Fetch-Site` value is neither `same-origin` nor `none`, for example:
	// - `same-site` same registrable domain (subdomain and/or different port)
	// - `cross-site` request originates from different site
	// Returning (true, nil) lets the request through, (false, nil) falls back to token based validation,
	// and a non-nil error is returned as-is (ErrorHandler is not invoked for it).
	// See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
	AllowSecFetchSiteFunc func(c *echo.Context) (bool, error)

	// TokenLength is the length of the generated token. It is only used when Generator is not set.
	// Optional. Default value 32.
	TokenLength uint8

	// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
	// to extract token from the request. At most one value is taken from each source.
	// Optional. Default value "header:X-CSRF-Token".
	// Possible values:
	// - "header:<name>" or "header:<name>:<cut-prefix>"
	// - "query:<name>"
	// - "form:<name>"
	// Multiple sources example:
	// - "header:X-CSRF-Token,query:csrf"
	TokenLookup string `yaml:"token_lookup"`

	// Generator defines a function to generate token.
	// Optional. Defaults to a random string generator producing tokens of TokenLength.
	Generator func() string

	// ContextKey is the context key under which the CSRF token is stored.
	// For requests allowed by the `Sec-Fetch-Site` check because of a safe method or a `same-origin`/`none`
	// value, the placeholder CSRFUsingSecFetchSite is stored instead.
	// Optional. Default value "csrf".
	ContextKey string

	// Name of the CSRF cookie. This cookie will store CSRF token.
	// Optional. Default value "_csrf".
	CookieName string

	// Domain of the CSRF cookie.
	// Optional. Default value none.
	CookieDomain string

	// Path of the CSRF cookie.
	// Optional. Default value none.
	CookiePath string

	// Max age (in seconds) of the CSRF cookie; used to compute the cookie's Expires attribute.
	// Optional. Default value 86400 (24hr).
	CookieMaxAge int

	// Indicates if CSRF cookie is secure.
	// Optional. Default value false. Forced to true when CookieSameSite is http.SameSiteNoneMode.
	CookieSecure bool

	// Indicates if CSRF cookie is HTTP only.
	// Optional. Default value false.
	CookieHTTPOnly bool

	// Indicates SameSite mode of the CSRF cookie.
	// Optional. Default value SameSiteDefaultMode, in which case no SameSite attribute is written.
	CookieSameSite http.SameSite

	// ErrorHandler defines a function which is executed for returning custom errors.
	// It receives token validation failures (ErrCSRFInvalid, or a 400 error wrapping the token extraction
	// error). Errors produced by the `Sec-Fetch-Site` check are returned directly without calling it.
	ErrorHandler func(c *echo.Context, err error) error
}
// ErrCSRFInvalid is returned (or passed to CSRFConfig.ErrorHandler) when the token supplied by the client does
// not match the token stored in the CSRF cookie.
var ErrCSRFInvalid = &echo.HTTPError{Code: http.StatusForbidden, Message: "invalid csrf token"}

// DefaultCSRFConfig is the default CSRF middleware config.
// CSRFConfig.ToMiddleware copies several of these values into zero-valued fields of the user supplied config
// (Skipper, TokenLength, TokenLookup, ContextKey, CookieName, CookieMaxAge).
var DefaultCSRFConfig = CSRFConfig{
	Skipper:        DefaultSkipper,
	TokenLength:    32,
	TokenLookup:    "header:" + echo.HeaderXCSRFToken,
	ContextKey:     "csrf",
	CookieName:     "_csrf",
	CookieMaxAge:   86400,
	CookieSameSite: http.SameSiteDefaultMode,
}
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware configured with DefaultCSRFConfig.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF() echo.MiddlewareFunc {
	return CSRFWithConfig(DefaultCSRFConfig)
}

// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration.
// Use CSRFConfig.ToMiddleware to receive the configuration error instead of a panic.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
	return toMiddlewareOrPanic(config)
}
// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
// (invalid TrustedOrigins or TokenLookup).
//
// The returned middleware first applies the `Sec-Fetch-Site` based check (checkSecFetchSiteRequest). If that
// check neither allows nor rejects the request, it falls back to token based protection:
//   - the token is taken from the CSRF cookie, or generated when the cookie is missing or empty;
//   - for methods other than GET, HEAD, OPTIONS and TRACE a client supplied token (see TokenLookup) must match
//     it, otherwise the request is rejected (fail-closed);
//   - the token is (re)sent as a cookie and stored in the context under ContextKey before calling next.
func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	// Defaults
	if config.Skipper == nil {
		config.Skipper = DefaultCSRFConfig.Skipper
	}
	if config.TokenLength == 0 {
		config.TokenLength = DefaultCSRFConfig.TokenLength
	}
	if config.Generator == nil {
		config.Generator = createRandomStringGenerator(config.TokenLength)
	}
	if config.TokenLookup == "" {
		config.TokenLookup = DefaultCSRFConfig.TokenLookup
	}
	if config.ContextKey == "" {
		config.ContextKey = DefaultCSRFConfig.ContextKey
	}
	if config.CookieName == "" {
		config.CookieName = DefaultCSRFConfig.CookieName
	}
	if config.CookieMaxAge == 0 {
		config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
	}
	// Browsers require SameSite=None cookies to also be marked Secure.
	if config.CookieSameSite == http.SameSiteNoneMode {
		config.CookieSecure = true
	}
	if len(config.TrustedOrigins) > 0 {
		if err := validateOrigins(config.TrustedOrigins, "trusted origin"); err != nil {
			return nil, err
		}
		// Copy so later changes to the caller's slice do not affect the running middleware.
		config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...)
	}

	extractors, cErr := createExtractors(config.TokenLookup, 1)
	if cErr != nil {
		return nil, cErr
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			// use the `Sec-Fetch-Site` header as part of a modern approach to CSRF protection
			allow, err := config.checkSecFetchSiteRequest(c)
			if err != nil {
				return err
			}
			if allow {
				return next(c)
			}

			// Fallback to legacy token based CSRF protection.
			// An empty cookie value is treated like a missing cookie: reusing it would let an empty client
			// token pass the comparison below and would hand an empty token to the handler.
			var token string
			if k, err := c.Cookie(config.CookieName); err == nil && k.Value != "" {
				token = k.Value // Reuse token
			} else {
				token = config.Generator() // Generate token
			}

			switch c.Request().Method {
			case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
			default:
				// Validate token only for requests which are not defined as 'safe' by RFC7231
				var lastExtractorErr error
				var lastTokenErr error
				validated := false
			outer:
				for _, extractor := range extractors {
					clientTokens, _, err := extractor(c)
					if err != nil {
						lastExtractorErr = err
						continue
					}
					for _, clientToken := range clientTokens {
						if validateCSRFToken(token, clientToken) {
							validated = true
							break outer
						}
						lastTokenErr = ErrCSRFInvalid
					}
				}

				if !validated {
					// A mismatching token takes precedence over an extraction failure. If no extractor produced
					// a value or an error at all, the request is still rejected (fail-closed).
					var finalErr error
					switch {
					case lastTokenErr != nil:
						finalErr = lastTokenErr
					case lastExtractorErr != nil:
						finalErr = echo.ErrBadRequest.Wrap(lastExtractorErr)
					default:
						finalErr = ErrCSRFInvalid
					}
					if config.ErrorHandler != nil {
						return config.ErrorHandler(c, finalErr)
					}
					return finalErr
				}
			}

			// Set CSRF cookie
			cookie := new(http.Cookie)
			cookie.Name = config.CookieName
			cookie.Value = token
			if config.CookiePath != "" {
				cookie.Path = config.CookiePath
			}
			if config.CookieDomain != "" {
				cookie.Domain = config.CookieDomain
			}
			if config.CookieSameSite != http.SameSiteDefaultMode {
				cookie.SameSite = config.CookieSameSite
			}
			cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
			cookie.Secure = config.CookieSecure
			cookie.HttpOnly = config.CookieHTTPOnly
			c.SetCookie(cookie)

			// Store token in the context
			c.Set(config.ContextKey, token)

			// Protect clients from caching the response
			c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)

			return next(c)
		}
	}, nil
}
// validateCSRFToken reports whether clientToken equals the server side token. The comparison uses
// subtle.ConstantTimeCompare, whose running time depends only on the input lengths and not on their contents.
// An empty server side token never validates; otherwise an empty CSRF cookie (or a Generator that returns "")
// would be satisfied by an empty client token.
func validateCSRFToken(token, clientToken string) bool {
	if token == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}
// safeMethods lists the HTTP methods RFC 7231 defines as "safe", i.e. not expected to change server state.
var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace}

// checkSecFetchSiteRequest applies Fetch Metadata based CSRF protection using the `Sec-Fetch-Site` header.
//
// Outcomes, in evaluation order:
//   - header absent: (false, nil), so the caller falls back to token based protection;
//   - Origin matches one of TrustedOrigins (case-insensitive): (true, nil);
//   - safe method, or a `same-origin` / `none` value: CSRFUsingSecFetchSite is stored under ContextKey and
//     (true, nil) is returned;
//   - remaining state-changing requests: the result of AllowSecFetchSiteFunc when it is set, otherwise
//     (false, 403 error) naming `same-site`, with every other value reported as `cross-site`.
//
// See https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error) {
	req := c.Request()

	site := req.Header.Get(echo.HeaderSecFetchSite)
	if site == "" {
		// No Fetch Metadata from the client: let the token based check decide.
		return false, nil
	}

	// Explicitly trusted origins (OAuth callbacks and similar) pass regardless of method or site value.
	if origin := req.Header.Get(echo.HeaderOrigin); origin != "" {
		for _, trusted := range config.TrustedOrigins {
			if strings.EqualFold(origin, trusted) {
				return true, nil
			}
		}
	}

	if slices.Contains(safeMethods, req.Method) || site == "same-origin" || site == "none" {
		// Give handlers that still render token based forms a placeholder value to work with.
		c.Set(config.ContextKey, CSRFUsingSecFetchSite)
		return true, nil
	}

	// State-changing request coming from `same-site`, `cross-site` (or an unrecognised value).
	// To allow `same-site` use config.TrustedOrigins or config.AllowSecFetchSiteFunc.
	if fn := config.AllowSecFetchSiteFunc; fn != nil {
		return fn(c)
	}

	msg := "cross-site request blocked by CSRF"
	if site == "same-site" {
		msg = "same-site request blocked by CSRF"
	}
	return false, echo.NewHTTPError(http.StatusForbidden, msg)
}
================================================
FILE: middleware/csrf_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"cmp"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestCSRF_tokenExtractors checks token extraction from headers, query parameters and form values
// (including multiple TokenLookup sources and the one-value-per-source limit) for state-changing requests.
func TestCSRF_tokenExtractors(t *testing.T) {
	var testCases = []struct {
		name                    string
		whenTokenLookup         string
		whenCookieName          string
		givenCSRFCookie         string
		givenMethod             string
		givenQueryTokens        map[string][]string
		givenFormTokens         map[string][]string
		givenHeaderTokens       map[string][]string
		expectError             string
		expectToMiddlewareError string
	}{
		{
			name:            "ok, multiple token lookups sources, succeeds on last one",
			whenTokenLookup: "header:X-CSRF-Token,form:csrf",
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPost,
			givenHeaderTokens: map[string][]string{
				echo.HeaderXCSRFToken: {"invalid_token"},
			},
			givenFormTokens: map[string][]string{
				"csrf": {"token"},
			},
		},
		{
			name:            "ok, token from POST form",
			whenTokenLookup: "form:csrf",
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPost,
			givenFormTokens: map[string][]string{
				"csrf": {"token"},
			},
		},
		{
			name:            "nok, token from POST form, tokens limited to 1, second token would pass",
			whenTokenLookup: "form:csrf",
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPost,
			givenFormTokens: map[string][]string{
				"csrf": {"invalid", "token"},
			},
			expectError: "code=403, message=invalid csrf token",
		},
		{
			name:            "nok, invalid token from POST form",
			whenTokenLookup: "form:csrf",
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPost,
			givenFormTokens: map[string][]string{
				"csrf": {"invalid_token"},
			},
			expectError: "code=403, message=invalid csrf token",
		},
		{
			name:            "nok, missing token from POST form",
			whenTokenLookup: "form:csrf",
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPost,
			givenFormTokens: map[string][]string{},
			expectError:     "code=400, message=Bad Request, err=missing value in the form",
		},
		{
			name:            "ok, token from POST header",
			whenTokenLookup: "", // will use defaults
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPost,
			givenHeaderTokens: map[string][]string{
				echo.HeaderXCSRFToken: {"token"},
			},
		},
		{
			name:            "nok, token from POST header, tokens limited to 1, second token would pass",
			whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPost,
			givenHeaderTokens: map[string][]string{
				echo.HeaderXCSRFToken: {"invalid", "token"},
			},
			expectError: "code=403, message=invalid csrf token",
		},
		{
			name:            "nok, invalid token from POST header",
			whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPost,
			givenHeaderTokens: map[string][]string{
				echo.HeaderXCSRFToken: {"invalid_token"},
			},
			expectError: "code=403, message=invalid csrf token",
		},
		{
			name:              "nok, missing token from POST header",
			whenTokenLookup:   "header:" + echo.HeaderXCSRFToken,
			givenCSRFCookie:   "token",
			givenMethod:       http.MethodPost,
			givenHeaderTokens: map[string][]string{},
			expectError:       "code=400, message=Bad Request, err=missing value in request header",
		},
		{
			name:            "ok, token from PUT query param",
			whenTokenLookup: "query:csrf-param",
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPut,
			givenQueryTokens: map[string][]string{
				"csrf-param": {"token"},
			},
		},
		{
			name:            "nok, token from PUT query form, second token would pass",
			whenTokenLookup: "query:csrf",
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPut,
			givenQueryTokens: map[string][]string{
				"csrf": {"invalid", "token"},
			},
			expectError: "code=403, message=invalid csrf token",
		},
		{
			name:            "nok, invalid token from PUT query form",
			whenTokenLookup: "query:csrf",
			givenCSRFCookie: "token",
			givenMethod:     http.MethodPut,
			givenQueryTokens: map[string][]string{
				"csrf": {"invalid_token"},
			},
			expectError: "code=403, message=invalid csrf token",
		},
		{
			name:             "nok, missing token from PUT query form",
			whenTokenLookup:  "query:csrf",
			givenCSRFCookie:  "token",
			givenMethod:      http.MethodPut,
			givenQueryTokens: map[string][]string{},
			expectError:      "code=400, message=Bad Request, err=missing value in the query string",
		},
		{
			name:                    "nok, invalid TokenLookup",
			whenTokenLookup:         "q",
			givenCSRFCookie:         "token",
			givenMethod:             http.MethodPut,
			givenQueryTokens:        map[string][]string{},
			expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()

			q := make(url.Values)
			for queryParam, values := range tc.givenQueryTokens {
				for _, v := range values {
					q.Add(queryParam, v)
				}
			}

			f := make(url.Values)
			for formKey, values := range tc.givenFormTokens {
				for _, v := range values {
					f.Add(formKey, v)
				}
			}

			var req *http.Request
			switch tc.givenMethod {
			case http.MethodGet:
				req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
			case http.MethodPost, http.MethodPut:
				// Use the case's own method so PUT cases really exercise a PUT request.
				req = httptest.NewRequest(tc.givenMethod, "/?"+q.Encode(), strings.NewReader(f.Encode()))
				req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
			default:
				t.Fatalf("unsupported test method: %s", tc.givenMethod)
			}

			for header, values := range tc.givenHeaderTokens {
				for _, v := range values {
					req.Header.Add(header, v)
				}
			}

			if tc.givenCSRFCookie != "" {
				cookieName := cmp.Or(tc.whenCookieName, DefaultCSRFConfig.CookieName)
				req.Header.Set(echo.HeaderCookie, cookieName+"="+tc.givenCSRFCookie)
			}

			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			config := CSRFConfig{
				TokenLookup: tc.whenTokenLookup,
				CookieName:  tc.whenCookieName,
			}
			csrf, err := config.ToMiddleware()
			if tc.expectToMiddlewareError != "" {
				assert.EqualError(t, err, tc.expectToMiddlewareError)
				return
			}
			if !assert.NoError(t, err) {
				return // csrf is nil; calling it would panic
			}

			h := csrf(func(c *echo.Context) error {
				return c.String(http.StatusOK, "test")
			})

			err = h(c)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestCSRFWithConfig exercises ToMiddleware defaults, token validation and the Sec-Fetch-Site shortcut. For each
// case it checks the returned error, the response body, the Set-Cookie header and the value stored under
// ContextKey. The token Generator is stubbed to return "TESTTOKEN" unless the case provides its own.
func TestCSRFWithConfig(t *testing.T) {
	token := randomString(16)

	var testCases = []struct {
		name                 string
		givenConfig          *CSRFConfig
		whenMethod           string
		whenHeaders          map[string]string
		expectEmptyBody      bool
		expectMWError        string
		expectCookieContains string
		expectTokenInContext string
		expectErr            string
	}{
		{
			name:                 "ok, GET",
			whenMethod:           http.MethodGet,
			expectCookieContains: "_csrf",
			expectTokenInContext: "TESTTOKEN",
		},
		{
			name: "ok, POST valid token",
			whenHeaders: map[string]string{
				echo.HeaderCookie:     "_csrf=" + token,
				echo.HeaderXCSRFToken: token,
			},
			whenMethod:           http.MethodPost,
			expectCookieContains: "_csrf",
			expectTokenInContext: token,
		},
		{
			name:            "nok, POST without token",
			whenMethod:      http.MethodPost,
			expectEmptyBody: true,
			expectErr:       `code=400, message=Bad Request, err=missing value in request header`,
		},
		{
			name:            "nok, POST empty token",
			whenHeaders:     map[string]string{echo.HeaderXCSRFToken: ""},
			whenMethod:      http.MethodPost,
			expectEmptyBody: true,
			expectErr:       `code=403, message=invalid csrf token`,
		},
		{
			name: "nok, invalid trusted origin in Config",
			givenConfig: &CSRFConfig{
				TrustedOrigins: []string{"http://example.com", "invalid"},
			},
			expectMWError: `trusted origin is missing scheme or host: invalid`,
		},
		{
			name: "ok, TokenLength",
			givenConfig: &CSRFConfig{
				TokenLength: 16,
			},
			whenMethod:           http.MethodGet,
			expectCookieContains: "_csrf",
			expectTokenInContext: "TESTTOKEN",
		},
		{
			name: "ok, unsafe method + SecFetchSite=same-origin passes",
			whenHeaders: map[string]string{
				echo.HeaderSecFetchSite: "same-origin",
			},
			whenMethod:           http.MethodPost,
			expectTokenInContext: "_echo_csrf_using_sec_fetch_site_",
		},
		{
			name: "ok, safe method + SecFetchSite=same-origin passes",
			whenHeaders: map[string]string{
				echo.HeaderSecFetchSite: "same-origin",
			},
			whenMethod:           http.MethodGet,
			expectTokenInContext: "_echo_csrf_using_sec_fetch_site_",
		},
		{
			name: "nok, unsafe method + unknown SecFetchSite value (same-cross) is blocked as cross-site",
			whenHeaders: map[string]string{
				echo.HeaderSecFetchSite: "same-cross",
			},
			whenMethod:      http.MethodPost,
			expectEmptyBody: true,
			expectErr:       `code=403, message=cross-site request blocked by CSRF`,
		},
		{
			name: "nok, unsafe method + SecFetchSite=cross-site blocked",
			whenHeaders: map[string]string{
				echo.HeaderSecFetchSite: "cross-site",
			},
			whenMethod:      http.MethodPost,
			expectEmptyBody: true,
			expectErr:       `code=403, message=cross-site request blocked by CSRF`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			req := httptest.NewRequest(cmp.Or(tc.whenMethod, http.MethodPost), "/", nil)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)
			for key, value := range tc.whenHeaders {
				req.Header.Set(key, value)
			}

			config := CSRFConfig{}
			if tc.givenConfig != nil {
				config = *tc.givenConfig
			}
			if config.Generator == nil {
				config.Generator = func() string {
					return "TESTTOKEN"
				}
			}

			mw, err := config.ToMiddleware()
			if tc.expectMWError != "" {
				assert.EqualError(t, err, tc.expectMWError)
				return
			}
			if !assert.NoError(t, err) {
				return // mw is nil; calling it would panic
			}

			h := mw(func(c *echo.Context) error {
				cToken := c.Get(cmp.Or(config.ContextKey, DefaultCSRFConfig.ContextKey))
				assert.Equal(t, tc.expectTokenInContext, cToken)
				return c.String(http.StatusOK, "test")
			})

			err = h(c)
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
			} else {
				assert.NoError(t, err)
			}

			expect := "test"
			if tc.expectEmptyBody {
				expect = ""
			}
			assert.Equal(t, expect, rec.Body.String())

			if tc.expectCookieContains != "" {
				assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), tc.expectCookieContains)
			}
		})
	}
}
// TestCSRF checks that the default CSRF middleware lets a safe GET request through and issues the `_csrf` cookie.
func TestCSRF(t *testing.T) {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	csrf := CSRF()
	h := csrf(func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	})

	// Generate CSRF token; the previously ignored handler error is now asserted.
	err := h(c)
	assert.NoError(t, err)
	assert.Equal(t, "test", rec.Body.String())
	assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
}
func TestCSRFSetSameSiteMode(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
CookieSameSite: http.SameSiteStrictMode,
})
h := csrf(func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
})
r := h(c)
assert.NoError(t, r)
assert.Regexp(t, "SameSite=Strict", rec.Header()["Set-Cookie"])
}
// TestCSRFWithoutSameSiteMode checks that leaving CookieSameSite unset produces a cookie without a SameSite attribute.
func TestCSRFWithoutSameSiteMode(t *testing.T) {
	rec := httptest.NewRecorder()
	ctx := echo.New().NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	handler := CSRFWithConfig(CSRFConfig{})(func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	})

	err := handler(ctx)
	assert.NoError(t, err)
	assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
}
func TestCSRFWithSameSiteDefaultMode(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
CookieSameSite: http.SameSiteDefaultMode,
})
h := csrf(func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
})
r := h(c)
assert.NoError(t, r)
assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
}
func TestCSRFWithSameSiteModeNone(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf, err := CSRFConfig{
CookieSameSite: http.SameSiteNoneMode,
}.ToMiddleware()
assert.NoError(t, err)
h := csrf(func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
})
r := h(c)
assert.NoError(t, r)
assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"])
assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"])
}
func TestCSRFConfig_skipper(t *testing.T) {
var testCases = []struct {
name string
whenSkip bool
expectCookies int
}{
{
name: "do skip",
whenSkip: true,
expectCookies: 0,
},
{
name: "do not skip",
whenSkip: false,
expectCookies: 1,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
Skipper: func(c *echo.Context) bool {
return tc.whenSkip
},
})
h := csrf(func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
})
r := h(c)
assert.NoError(t, r)
cookie := rec.Header()["Set-Cookie"]
assert.Len(t, cookie, tc.expectCookies)
})
}
}
func TestCSRFErrorHandling(t *testing.T) {
cfg := CSRFConfig{
ErrorHandler: func(c *echo.Context, err error) error {
return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed")
},
}
e := echo.New()
e.POST("/", func(c *echo.Context) error {
return c.String(http.StatusNotImplemented, "should not end up here")
})
e.Use(CSRFWithConfig(cfg))
req := httptest.NewRequest(http.MethodPost, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String())
}
// TestCSRFConfig_checkSecFetchSiteRequest is a table test for the Sec-Fetch-Site based
// pre-check of the CSRF middleware.
//
// Each case builds a request with the given method, optional Sec-Fetch-Site and Origin
// headers, calls checkSecFetchSiteRequest directly and compares:
//   - expectAllow: whether the pre-check alone allowed the request;
//   - expectErr: the exact error text, or "" when no error is expected.
//
// A case with expectAllow == false and expectErr == "" means the pre-check neither allowed
// nor rejected the request (per the case comments, it is left to token-based CSRF checking).
func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
	var testCases = []struct {
		name             string
		givenConfig      CSRFConfig
		whenMethod       string
		whenSecFetchSite string // empty means the header is not set
		whenOrigin       string // empty means the header is not set
		expectAllow      bool
		expectErr        string
	}{
		{
			// NOTE(review): the "ok ... not blocked" name describes that no error is
			// returned; expectAllow is false because the decision is deferred.
			name:             "ok, unsafe POST, no SecFetchSite is not blocked",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "",
			expectAllow:      false, // should fall back to token CSRF
		},
		// Safe methods pass regardless of Sec-Fetch-Site.
		{
			name:             "ok, safe GET + same-origin passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodGet,
			whenSecFetchSite: "same-origin",
			expectAllow:      true,
		},
		{
			name:             "ok, safe GET + none passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodGet,
			whenSecFetchSite: "none",
			expectAllow:      true,
		},
		{
			name:             "ok, safe GET + same-site passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodGet,
			whenSecFetchSite: "same-site",
			expectAllow:      true,
		},
		{
			name:             "ok, safe GET + cross-site passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodGet,
			whenSecFetchSite: "cross-site",
			expectAllow:      true,
		},
		// Unsafe methods: same-origin/none pass, same-site/cross-site are blocked.
		{
			name:             "nok, unsafe POST + cross-site is blocked",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			expectAllow:      false,
			expectErr:        `code=403, message=cross-site request blocked by CSRF`,
		},
		{
			name:             "nok, unsafe POST + same-site is blocked",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "same-site",
			expectAllow:      false,
			expectErr:        `code=403, message=same-site request blocked by CSRF`,
		},
		{
			name:             "ok, unsafe POST + same-origin passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "same-origin",
			expectAllow:      true,
		},
		{
			name:             "ok, unsafe POST + none passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "none",
			expectAllow:      true,
		},
		{
			name:             "ok, unsafe PUT + same-origin passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPut,
			whenSecFetchSite: "same-origin",
			expectAllow:      true,
		},
		{
			name:             "ok, unsafe PUT + none passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPut,
			whenSecFetchSite: "none",
			expectAllow:      true,
		},
		{
			name:             "ok, unsafe DELETE + same-origin passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodDelete,
			whenSecFetchSite: "same-origin",
			expectAllow:      true,
		},
		{
			name:             "ok, unsafe PATCH + same-origin passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPatch,
			whenSecFetchSite: "same-origin",
			expectAllow:      true,
		},
		{
			name:             "nok, unsafe PUT + cross-site is blocked",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPut,
			whenSecFetchSite: "cross-site",
			expectAllow:      false,
			expectErr:        `code=403, message=cross-site request blocked by CSRF`,
		},
		{
			name:             "nok, unsafe PUT + same-site is blocked",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPut,
			whenSecFetchSite: "same-site",
			expectAllow:      false,
			expectErr:        `code=403, message=same-site request blocked by CSRF`,
		},
		{
			name:             "nok, unsafe DELETE + cross-site is blocked",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodDelete,
			whenSecFetchSite: "cross-site",
			expectAllow:      false,
			expectErr:        `code=403, message=cross-site request blocked by CSRF`,
		},
		{
			name:             "nok, unsafe DELETE + same-site is blocked",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodDelete,
			whenSecFetchSite: "same-site",
			expectAllow:      false,
			expectErr:        `code=403, message=same-site request blocked by CSRF`,
		},
		{
			name:             "nok, unsafe PATCH + cross-site is blocked",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPatch,
			whenSecFetchSite: "cross-site",
			expectAllow:      false,
			expectErr:        `code=403, message=cross-site request blocked by CSRF`,
		},
		// Other safe methods.
		{
			name:             "ok, safe HEAD + same-origin passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodHead,
			whenSecFetchSite: "same-origin",
			expectAllow:      true,
		},
		{
			name:             "ok, safe HEAD + cross-site passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodHead,
			whenSecFetchSite: "cross-site",
			expectAllow:      true,
		},
		{
			name:             "ok, safe OPTIONS + cross-site passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodOptions,
			whenSecFetchSite: "cross-site",
			expectAllow:      true,
		},
		{
			name:             "ok, safe TRACE + cross-site passes",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodTrace,
			whenSecFetchSite: "cross-site",
			expectAllow:      true,
		},
		// TrustedOrigins: a matching Origin header lifts the same-site/cross-site block.
		{
			name: "ok, unsafe POST + cross-site + matching trusted origin passes",
			givenConfig: CSRFConfig{
				TrustedOrigins: []string{"https://trusted.example.com"},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			whenOrigin:       "https://trusted.example.com",
			expectAllow:      true,
		},
		{
			name: "ok, unsafe POST + same-site + matching trusted origin passes",
			givenConfig: CSRFConfig{
				TrustedOrigins: []string{"https://trusted.example.com"},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "same-site",
			whenOrigin:       "https://trusted.example.com",
			expectAllow:      true,
		},
		{
			name: "nok, unsafe POST + cross-site + non-matching origin is blocked",
			givenConfig: CSRFConfig{
				TrustedOrigins: []string{"https://trusted.example.com"},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			whenOrigin:       "https://evil.example.com",
			expectAllow:      false,
			expectErr:        `code=403, message=cross-site request blocked by CSRF`,
		},
		{
			name: "ok, unsafe POST + cross-site + case-insensitive trusted origin match passes",
			givenConfig: CSRFConfig{
				TrustedOrigins: []string{"https://trusted.example.com"},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			whenOrigin:       "https://TRUSTED.example.com",
			expectAllow:      true,
		},
		{
			name: "ok, unsafe POST + same-origin + trusted origins configured but not matched passes",
			givenConfig: CSRFConfig{
				TrustedOrigins: []string{"https://trusted.example.com"},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "same-origin",
			whenOrigin:       "https://different.example.com",
			expectAllow:      true,
		},
		{
			name: "nok, unsafe POST + cross-site + empty origin + trusted origins configured is blocked",
			givenConfig: CSRFConfig{
				TrustedOrigins: []string{"https://trusted.example.com"},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			whenOrigin:       "",
			expectAllow:      false,
			expectErr:        `code=403, message=cross-site request blocked by CSRF`,
		},
		{
			name: "ok, unsafe POST + cross-site + multiple trusted origins, second one matches",
			givenConfig: CSRFConfig{
				TrustedOrigins: []string{"https://first.example.com", "https://second.example.com"},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			whenOrigin:       "https://second.example.com",
			expectAllow:      true,
		},
		// AllowSecFetchSiteFunc: its (allow, err) result is returned as-is for blocked requests.
		{
			name: "ok, unsafe POST + same-site + custom func allows",
			givenConfig: CSRFConfig{
				AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
					return true, nil
				},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "same-site",
			expectAllow:      true,
		},
		{
			name: "ok, unsafe POST + cross-site + custom func allows",
			givenConfig: CSRFConfig{
				AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
					return true, nil
				},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			expectAllow:      true,
		},
		{
			name: "nok, unsafe POST + same-site + custom func returns custom error",
			givenConfig: CSRFConfig{
				AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
					return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func")
				},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "same-site",
			expectAllow:      false,
			expectErr:        `code=418, message=custom error from func`,
		},
		{
			name: "nok, unsafe POST + cross-site + custom func returns false with nil error",
			givenConfig: CSRFConfig{
				AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
					return false, nil
				},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			expectAllow:      false,
			expectErr:        "", // custom func returns nil error, so no error expected
		},
		{
			name:             "nok, unsafe POST + invalid Sec-Fetch-Site value treated as cross-site",
			givenConfig:      CSRFConfig{},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "invalid-value",
			expectAllow:      false,
			expectErr:        `code=403, message=cross-site request blocked by CSRF`,
		},
		// Interaction: a trusted-origin match is evaluated before AllowSecFetchSiteFunc.
		{
			name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func",
			givenConfig: CSRFConfig{
				TrustedOrigins: []string{"https://trusted.example.com"},
				AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
					return false, echo.NewHTTPError(http.StatusTeapot, "should not be called")
				},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			whenOrigin:       "https://trusted.example.com",
			expectAllow:      true,
		},
		{
			name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks",
			givenConfig: CSRFConfig{
				TrustedOrigins: []string{"https://trusted.example.com"},
				AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
					return false, echo.NewHTTPError(http.StatusTeapot, "custom block")
				},
			},
			whenMethod:       http.MethodPost,
			whenSecFetchSite: "cross-site",
			whenOrigin:       "https://evil.example.com",
			expectAllow:      false,
			expectErr:        `code=418, message=custom block`,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(tc.whenMethod, "/", nil)
			// Only set headers the case provides, so "" means "header absent".
			if tc.whenSecFetchSite != "" {
				req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite)
			}
			if tc.whenOrigin != "" {
				req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
			}
			res := httptest.NewRecorder()
			c := echo.NewContext(req, res)
			// Exercise the pre-check directly, without the rest of the middleware.
			allow, err := tc.givenConfig.checkSecFetchSiteRequest(c)
			assert.Equal(t, tc.expectAllow, allow)
			if tc.expectErr != "" {
				assert.EqualError(t, err, tc.expectErr)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
================================================
FILE: middleware/decompress.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"compress/gzip"
"io"
"net/http"
"sync"
"github.com/labstack/echo/v5"
)
// DecompressConfig defines the config for Decompress middleware.
type DecompressConfig struct {
	// Skipper defines a function to skip middleware.
	// Defaults to DefaultSkipper when nil.
	Skipper Skipper
	// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers.
	// Defaults to &DefaultGzipDecompressPool{} when nil.
	GzipDecompressPool Decompressor
	// MaxDecompressedSize limits the maximum size of decompressed request body in bytes.
	// If the decompressed body exceeds this limit, reading the body fails with an HTTP 413 error.
	// This prevents zip bomb attacks where small compressed payloads decompress to huge sizes.
	// Default (zero value): 100 * MB (104,857,600 bytes)
	// Set to -1 to disable limits (not recommended in production). Any negative value
	// disables the limit.
	MaxDecompressedSize int64
}
// GZIPEncoding is the Content-Encoding request header value that makes the Decompress
// middleware gunzip the request body. The header is compared to it exactly (case-sensitively).
const GZIPEncoding string = "gzip"
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers.
// The pool is expected to hand out *gzip.Reader values; the middleware responds with
// HTTP 500 for anything else (an error value's message is used as the response message).
// The method is unexported, so only types in this package can implement the interface.
type Decompressor interface {
	gzipDecompressPool() sync.Pool
}
// DefaultGzipDecompressPool is the default implementation of Decompressor interface.
// Its pool lazily allocates zero-value *gzip.Reader instances, which must be Reset onto
// a source before they are read from.
type DefaultGzipDecompressPool struct {
}

func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
	newReader := func() any {
		return new(gzip.Reader)
	}
	return sync.Pool{New: newReader}
}
// Decompress returns middleware that gunzips the request body when the Content-Encoding
// header is "gzip", using a zero-value (default) DecompressConfig.
//
// SECURITY: By default, this limits decompressed data to 100MB to prevent zip bomb attacks.
// To customize the limit, use DecompressWithConfig. To disable limits (not recommended in production),
// set MaxDecompressedSize to -1.
func Decompress() echo.MiddlewareFunc {
	var defaults DecompressConfig
	return DecompressWithConfig(defaults)
}
// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration.
//
// SECURITY: If MaxDecompressedSize is not set (zero value), it defaults to 100MB to prevent
// DoS attacks via zip bombs. Set to -1 to explicitly disable limits if needed for your use case.
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
	// NOTE(review): toMiddlewareOrPanic is defined elsewhere in the package; presumably it
	// calls config.ToMiddleware and panics on a non-nil error — confirm.
	return toMiddlewareOrPanic(config)
}
// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration.
//
// Zero-value fields get defaults: DefaultSkipper, &DefaultGzipDecompressPool{} and a
// MaxDecompressedSize of 100 * MB. A negative MaxDecompressedSize disables the limit.
// The returned error is currently always nil.
//
// For each request with "Content-Encoding: gzip" the request body is replaced by a
// reader that yields the decompressed bytes; the Content-Encoding header is not changed.
func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}
	if config.GzipDecompressPool == nil {
		config.GzipDecompressPool = &DefaultGzipDecompressPool{}
	}
	// Apply secure default for decompression limit
	if config.MaxDecompressedSize == 0 {
		config.MaxDecompressedSize = 100 * MB
	}
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		// One pool per wrapped handler, shared by every request that handler serves.
		pool := config.GzipDecompressPool.gzipDecompressPool()
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}
			if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding {
				return next(c)
			}
			i := pool.Get()
			gr, ok := i.(*gzip.Reader)
			if !ok || gr == nil {
				// A custom pool may yield an error value; surface its message as HTTP 500.
				if err, isErr := i.(error); isErr {
					return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
				}
				return echo.NewHTTPError(http.StatusInternalServerError, "unexpected type from gzip decompression pool")
			}
			// Deferred calls run LIFO: the reader is Closed (below) before it is returned to
			// the pool, and the original body is closed last.
			defer pool.Put(gr)
			b := c.Request().Body
			defer b.Close()
			if err := gr.Reset(b); err != nil {
				if err == io.EOF { //ignore if body is empty
					return next(c)
				}
				// Any other Reset error (e.g. a body that is not valid gzip) is returned as-is.
				return err
			}
			// only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close.
			defer gr.Close()
			// Apply decompression size limit to prevent zip bombs
			if config.MaxDecompressedSize > 0 {
				c.Request().Body = &limitedGzipReader{
					Reader:    gr,
					remaining: config.MaxDecompressedSize,
					limit:     config.MaxDecompressedSize,
				}
			} else {
				// -1 means explicitly unlimited (not recommended)
				c.Request().Body = gr
			}
			// NOTE(review): after next returns, gr goes back to the pool while
			// c.Request().Body may still reference it; the body must not be read after
			// this handler completes.
			return next(c)
		}
	}, nil
}
// limitedGzipReader wraps a gzip reader with size limiting to prevent zip bombs.
// It is installed as the request body when MaxDecompressedSize > 0.
type limitedGzipReader struct {
	*gzip.Reader
	// remaining is the decompressed-byte budget still available to callers of Read.
	remaining int64
	// limit is the configured maximum; Read only consults remaining.
	limit int64
}
// Read reads decompressed data into p while enforcing the configured size limit.
//
// A body whose decompressed size is exactly the limit is accepted. Read does not fail
// as soon as the budget reaches zero; instead it asks the underlying reader for at most
// remaining+1 bytes, and only a byte beyond the budget counts as exceeding it. This is
// the same technique net/http.MaxBytesReader uses.
//
// When the limit is exceeded, the bytes that still fit are returned together with
// echo.ErrStatusRequestEntityTooLarge and remaining is set to -1. Every later call then
// fails with the same error without reading any more compressed input.
func (r *limitedGzipReader) Read(p []byte) (n int, err error) {
	if r.remaining < 0 {
		// Limit already exceeded by an earlier call.
		return 0, echo.ErrStatusRequestEntityTooLarge
	}
	if len(p) == 0 {
		return 0, nil
	}
	// One byte past the budget is enough to tell "ends exactly at the limit" from
	// "goes past it". Written as len(p)-1 > remaining so that remaining+1 cannot
	// overflow even for a limit close to math.MaxInt64.
	if int64(len(p))-1 > r.remaining {
		p = p[:r.remaining+1]
	}
	n, err = r.Reader.Read(p)
	if int64(n) > r.remaining {
		// Limit exceeded - hand back what fits and return 413 error.
		n = int(r.remaining)
		r.remaining = -1
		return n, echo.ErrStatusRequestEntityTooLarge
	}
	r.remaining -= int64(n)
	return n, err
}
// Close closes the wrapped gzip reader. The original request body is not closed here;
// the middleware closes it separately.
func (r *limitedGzipReader) Close() error {
	return r.Reader.Close()
}
================================================
FILE: middleware/decompress_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bytes"
"compress/gzip"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestDecompress checks that the default middleware gunzips a gzip-encoded request body
// in place while leaving the Content-Encoding request header untouched.
func TestDecompress(t *testing.T) {
	const body = `{"name": "echo"}`
	compressed, _ := gzipString(body)

	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(compressed)))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()

	e := echo.New()
	c := e.NewContext(req, rec)
	h := Decompress()(func(c *echo.Context) error {
		c.Response().Write([]byte("test")) // For Content-Type sniffing
		return nil
	})

	assert.NoError(t, h(c))
	assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))

	decoded, err := io.ReadAll(req.Body)
	assert.NoError(t, err)
	assert.Equal(t, body, string(decoded))
}
// TestDecompress_skippedIfNoHeader checks that a request without a Content-Encoding
// header passes through the middleware untouched and the handler still runs.
func TestDecompress_skippedIfNoHeader(t *testing.T) {
	h := Decompress()(func(c *echo.Context) error {
		c.Response().Write([]byte("test")) // For Content-Type sniffing
		return nil
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
	c := echo.New().NewContext(req, rec)

	assert.NoError(t, h(c))
	assert.Equal(t, "test", rec.Body.String())
}
// TestDecompressWithConfig_DefaultConfig_noDecode checks that a zero-value config builds
// without error and leaves non-gzip requests alone.
func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
	mw, err := DecompressConfig{}.ToMiddleware()
	assert.NoError(t, err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
	c := echo.New().NewContext(req, rec)

	handlerErr := mw(func(c *echo.Context) error {
		c.Response().Write([]byte("test")) // For Content-Type sniffing
		return nil
	})(c)
	assert.NoError(t, handlerErr)
	assert.Equal(t, "test", rec.Body.String())
}
// TestDecompressWithConfig_DefaultConfig checks that the default middleware decodes a
// gzip body and keeps the Content-Encoding request header as sent.
func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
	body := `{"name": "echo"}`
	compressed, _ := gzipString(body)

	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(compressed)))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	mw := Decompress()
	handlerErr := mw(func(c *echo.Context) error {
		c.Response().Write([]byte("test")) // For Content-Type sniffing
		return nil
	})(c)
	assert.NoError(t, handlerErr)
	assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))

	decoded, err := io.ReadAll(req.Body)
	assert.NoError(t, err)
	assert.Equal(t, body, string(decoded))
}
// TestCompressRequestWithoutDecompressMiddleware verifies that, without the Decompress
// middleware, a gzip-encoded request body reaches the application still compressed.
func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
	e := echo.New()
	body := `{"name":"echo"}`
	gz, _ := gzipString(body)
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)

	assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
	b, err := io.ReadAll(req.Body)
	assert.NoError(t, err)
	// Compare like types: a []byte vs string comparison is always "not equal" in
	// testify, which made the previous assertion vacuous.
	assert.NotEqual(t, body, string(b))
	assert.Equal(t, gz, b)
}
// TestDecompressNoContent checks that a gzip-labelled request with an empty body goes
// straight to the handler, and that a 204 response has no Content-Type and no body.
func TestDecompressNoContent(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	h := Decompress()(func(c *echo.Context) error {
		return c.NoContent(http.StatusNoContent)
	})

	if err := h(c); !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
	assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
	assert.Equal(t, 0, len(rec.Body.Bytes()))
}
// TestDecompressErrorReturned checks that an error returned by the handler behind the
// middleware reaches the client and that no Content-Encoding response header is added.
func TestDecompressErrorReturned(t *testing.T) {
	e := echo.New()
	e.Use(Decompress())
	e.GET("/", func(*echo.Context) error {
		return echo.ErrNotFound
	})

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusNotFound, rec.Code)
	assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
// TestDecompressSkipper checks that a request matched by Skipper is not decompressed:
// its (plain, not actually gzipped) body must remain readable unchanged.
func TestDecompressSkipper(t *testing.T) {
	e := echo.New()
	e.Use(DecompressWithConfig(DecompressConfig{
		Skipper: func(c *echo.Context) bool {
			return c.Request().URL.Path == "/skip"
		},
	}))
	body := `{"name": "echo"}`
	req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)

	// testify convention: expected value first.
	assert.Equal(t, echo.MIMEApplicationJSON, rec.Header().Get(echo.HeaderContentType))
	// The middleware works on req itself, so reading req.Body directly is sufficient.
	reqBody, err := io.ReadAll(req.Body)
	assert.NoError(t, err)
	assert.Equal(t, body, string(reqBody))
}
// TestDecompressPoolWithError is a Decompressor whose pool hands out an error value
// instead of a *gzip.Reader, used to exercise the middleware's pool-failure path.
type TestDecompressPoolWithError struct {
}

func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
	failing := func() any { return errors.New("pool error") }
	return sync.Pool{New: failing}
}
// TestDecompressPoolError checks that a pool handing out a non-reader value makes the
// middleware answer 500 without consuming the request body.
func TestDecompressPoolError(t *testing.T) {
	e := echo.New()
	e.Use(DecompressWithConfig(DecompressConfig{
		Skipper:            DefaultSkipper,
		GzipDecompressPool: &TestDecompressPoolWithError{},
	}))
	body := `{"name": "echo"}`
	req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)

	assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
	// The middleware works on req itself, so reading req.Body directly is sufficient.
	reqBody, err := io.ReadAll(req.Body)
	assert.NoError(t, err)
	assert.Equal(t, body, string(reqBody))
	// testify convention: expected value first.
	assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
// BenchmarkDecompress measures the default Decompress middleware on a small
// gzip-encoded JSON body.
//
// A fresh request is built on every iteration. The middleware consumes the body and
// replaces req.Body with a reader over a pooled gzip reader, so reusing one request
// would benchmark an already drained (and aliased) body after the first iteration.
// The handler drains the body so that decompression itself is measured.
func BenchmarkDecompress(b *testing.B) {
	e := echo.New()
	body := `{"name": "echo"}`
	gz, _ := gzipString(body)
	h := Decompress()(func(c *echo.Context) error {
		io.Copy(io.Discard, c.Request().Body)
		c.Response().Write([]byte(body)) // For Content-Type sniffing
		return nil
	})
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
		req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
		rec := httptest.NewRecorder()
		c := e.NewContext(req, rec)
		h(c)
	}
}
// gzipString gzip-compresses body and returns the complete stream (header, deflate
// data and trailer), or the first error from writing or closing the gzip writer.
func gzipString(body string) ([]byte, error) {
	buf := new(bytes.Buffer)
	w := gzip.NewWriter(buf)
	if _, err := w.Write([]byte(body)); err != nil {
		return nil, err
	}
	// Close flushes the remaining deflate data and writes the CRC/size trailer.
	if err := w.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// TestDecompress_WithinLimit checks that a body well under MaxDecompressedSize is
// delivered to the handler in full.
func TestDecompress_WithinLimit(t *testing.T) {
	body := strings.Repeat("test data ", 100) // Small payload ~1KB
	gz, _ := gzipString(body)

	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	mw, err := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware()
	assert.NoError(t, err)

	reflectBody := func(c *echo.Context) error {
		decoded, _ := io.ReadAll(c.Request().Body)
		return c.String(http.StatusOK, string(decoded))
	}
	assert.NoError(t, mw(reflectBody)(c))
	assert.Equal(t, body, rec.Body.String())
}
// TestDecompress_ExceedsLimit checks that reading a body larger than
// MaxDecompressedSize fails with an HTTP 413 error.
func TestDecompress_ExceedsLimit(t *testing.T) {
	// 2KB of payload against a 1KB limit.
	gz, _ := gzipString(strings.Repeat("A", 2*1024))

	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	c := echo.New().NewContext(req, httptest.NewRecorder())

	mw, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
	assert.NoError(t, err)

	err = mw(func(c *echo.Context) error {
		_, readErr := io.ReadAll(c.Request().Body)
		return readErr
	})(c)

	// The oversized body must surface as HTTP 413.
	assert.Error(t, err)
	sc, ok := err.(echo.HTTPStatusCoder)
	assert.True(t, ok)
	assert.Equal(t, http.StatusRequestEntityTooLarge, sc.StatusCode())
}
// TestDecompress_AtExactLimit checks that a body whose decompressed size equals
// MaxDecompressedSize is accepted. The read error is checked rather than ignored:
// reading such a body must end with io.EOF, not with a 413.
func TestDecompress_AtExactLimit(t *testing.T) {
	e := echo.New()
	exactBody := strings.Repeat("B", 1024) // Exactly 1KB
	gz, _ := gzipString(exactBody)
	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware()
	assert.NoError(t, err)
	err = h(func(c *echo.Context) error {
		b, readErr := io.ReadAll(c.Request().Body)
		if readErr != nil {
			return readErr
		}
		return c.String(http.StatusOK, string(b))
	})(c)
	assert.NoError(t, err)
	assert.Equal(t, exactBody, rec.Body.String())
}
// TestDecompress_ZipBomb feeds a highly compressible 2MB payload through a 1MB limit
// and expects reading it to fail with HTTP 413.
func TestDecompress_ZipBomb(t *testing.T) {
	var compressed bytes.Buffer
	zw := gzip.NewWriter(&compressed)
	zw.Write(bytes.Repeat([]byte("A"), 2*1024*1024)) // 2MB
	zw.Close()

	req := httptest.NewRequest(http.MethodPost, "/", &compressed)
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	c := echo.New().NewContext(req, httptest.NewRecorder())

	mw, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware()
	assert.NoError(t, err)

	err = mw(func(c *echo.Context) error {
		_, readErr := io.ReadAll(c.Request().Body)
		return readErr
	})(c)

	// The expanded data exceeds the limit, so a 413 is expected.
	assert.Error(t, err)
	sc, ok := err.(echo.HTTPStatusCoder)
	assert.True(t, ok)
	assert.Equal(t, http.StatusRequestEntityTooLarge, sc.StatusCode())
}
// TestDecompress_UnlimitedExplicit checks that MaxDecompressedSize: -1 disables the
// limit and the whole body is delivered.
func TestDecompress_UnlimitedExplicit(t *testing.T) {
	payload := strings.Repeat("X", 10*1024) // 10KB
	gz, _ := gzipString(payload)

	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	mw, err := DecompressConfig{MaxDecompressedSize: -1}.ToMiddleware() // Unlimited
	assert.NoError(t, err)

	reflectBody := func(c *echo.Context) error {
		decoded, _ := io.ReadAll(c.Request().Body)
		return c.String(http.StatusOK, string(decoded))
	}
	assert.NoError(t, mw(reflectBody)(c))
	assert.Equal(t, payload, rec.Body.String())
}
// TestDecompress_DefaultLimit checks that a zero-value config (100MB default limit)
// passes a small body through unchanged.
func TestDecompress_DefaultLimit(t *testing.T) {
	const smallBody = "test"
	gz, _ := gzipString(smallBody)

	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	// Use zero value which should apply 100MB default
	mw, err := DecompressConfig{}.ToMiddleware()
	assert.NoError(t, err)

	reflectBody := func(c *echo.Context) error {
		decoded, _ := io.ReadAll(c.Request().Body)
		return c.String(http.StatusOK, string(decoded))
	}
	assert.NoError(t, mw(reflectBody)(c))
	assert.Equal(t, smallBody, rec.Body.String())
}
// TestDecompress_SmallCustomLimit checks that a body below a small custom limit
// (512 bytes against 1KB) is delivered in full.
func TestDecompress_SmallCustomLimit(t *testing.T) {
	payload := strings.Repeat("D", 512) // 512 bytes
	gz, _ := gzipString(payload)

	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	mw, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
	assert.NoError(t, err)

	reflectBody := func(c *echo.Context) error {
		decoded, _ := io.ReadAll(c.Request().Body)
		return c.String(http.StatusOK, string(decoded))
	}
	assert.NoError(t, mw(reflectBody)(c))
	assert.Equal(t, payload, rec.Body.String())
}
// TestDecompress_MultipleReads checks that the limit is enforced across many small
// Read calls: exactly MaxDecompressedSize bytes are delivered, then reading fails
// with HTTP 413.
func TestDecompress_MultipleReads(t *testing.T) {
	e := echo.New()
	// Test that limit is enforced across multiple Read() calls
	largeBody := strings.Repeat("M", 2*1024) // 2KB
	gz, _ := gzipString(largeBody)
	req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
	assert.NoError(t, err)

	total := 0
	err = h(func(c *echo.Context) error {
		// Read in small chunks
		buf := make([]byte, 256)
		for {
			n, readErr := c.Request().Body.Read(buf)
			total += n
			if readErr == io.EOF {
				return nil
			}
			if readErr != nil {
				return readErr
			}
		}
	})(c)

	// Should return 413 error from cumulative reads
	assert.Error(t, err)
	if he, ok := err.(echo.HTTPStatusCoder); assert.True(t, ok) {
		assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
	}
	// Everything up to the limit, and nothing beyond it, must have been handed out.
	assert.Equal(t, 1024, total)
}
// TestDecompress_LargePayloadDosPrevention simulates a DoS attempt: a 10MB payload that
// compresses very well is sent against a 1MB limit and must be rejected with HTTP 413.
func TestDecompress_LargePayloadDosPrevention(t *testing.T) {
	const largeSize = 10 * 1024 * 1024 // 10MB decompressed

	var compressed bytes.Buffer
	zw := gzip.NewWriter(&compressed)
	zw.Write(bytes.Repeat([]byte("Z"), largeSize))
	zw.Close()

	req := httptest.NewRequest(http.MethodPost, "/", &compressed)
	req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
	c := echo.New().NewContext(req, httptest.NewRecorder())

	mw, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware()
	assert.NoError(t, err)

	err = mw(func(c *echo.Context) error {
		_, readErr := io.ReadAll(c.Request().Body)
		return readErr
	})(c)

	// Should prevent DoS by returning 413
	assert.Error(t, err)
	sc, ok := err.(echo.HTTPStatusCoder)
	assert.True(t, ok)
	assert.Equal(t, http.StatusRequestEntityTooLarge, sc.StatusCode())
}
// BenchmarkDecompress_WithLimit measures a full decompress-and-drain cycle of a ~15KB
// payload through the size-limited reader. The middleware is applied to the handler on
// every iteration.
func BenchmarkDecompress_WithLimit(b *testing.B) {
	payload, _ := gzipString(strings.Repeat("benchmark data ", 1000)) // ~15KB
	mw, _ := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware()
	e := echo.New()

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(payload))
		req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
		c := e.NewContext(req, httptest.NewRecorder())
		mw(func(c *echo.Context) error {
			io.ReadAll(c.Request().Body)
			return nil
		})(c)
	}
}
================================================
FILE: middleware/extractor.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"fmt"
"net/textproto"
"strings"
"github.com/labstack/echo/v5"
)
const (
	// extractorLimit is the upper bound for the limit argument of createExtractors
	// (larger values are clamped to it), i.e. the most values one extractor may return.
	// The number itself is arbitrary; the cap limits a possible resource exhaustion
	// attack vector.
	extractorLimit = 20
)
// ExtractorSource is a type that indicates where an extracted value came from.
// Every ValuesExtractor reports one of the constants below alongside its values.
type ExtractorSource string

const (
	// ExtractorSourceHeader means value was extracted from request header
	ExtractorSourceHeader ExtractorSource = "header"
	// ExtractorSourceQuery means value was extracted from request query parameters
	ExtractorSourceQuery ExtractorSource = "query"
	// ExtractorSourcePathParam means value was extracted from route path parameters
	ExtractorSourcePathParam ExtractorSource = "param"
	// ExtractorSourceCookie means value was extracted from request cookies
	ExtractorSourceCookie ExtractorSource = "cookie"
	// ExtractorSourceForm means value was extracted from request form values
	ExtractorSourceForm ExtractorSource = "form"
)
// ValueExtractorError is the error type returned when a middleware extractor cannot
// extract a usable value from its lookup.
type ValueExtractorError struct {
	message string
}

// Error returns the error's message text.
func (e *ValueExtractorError) Error() string {
	msg := e.message
	return msg
}
// Sentinel errors returned by the built-in extractors. They are shared pointers, so
// callers may compare against them by identity.
var (
	errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"}
	errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"}
	errQueryExtractorValueMissing  = &ValueExtractorError{message: "missing value in the query string"}
	errParamExtractorValueMissing  = &ValueExtractorError{message: "missing value in path params"}
	errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"}
	errFormExtractorValueMissing   = &ValueExtractorError{message: "missing value in the form"}
)
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
// It returns the extracted values, the source they came from, and an error (the built-in
// extractors use *ValueExtractorError) when nothing usable was found.
type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error)
// CreateExtractors creates ValuesExtractors from given lookups.
// lookups is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>"
// that is used to extract key from the request.
// Possible values:
//   - "header:<name>" or "header:<name>:<cut-prefix>"
//     `<cut-prefix>` is an argument value used to cut/trim a prefix from the extracted value.
//     This is useful if the header value has a static prefix, like
//     `Authorization: <auth-scheme> <authorisation-parameters>`, where the part we want to cut
//     is `<auth-scheme> ` (note the space at the end).
//     In case of basic authentication `Authorization: Basic <credentials>` the prefix we want
//     to remove is `Basic `.
//   - "query:<name>"
//   - "param:<name>"
//   - "form:<name>"
//   - "cookie:<name>"
//
// Multiple sources example:
//   - "header:Authorization,header:X-Api-Key"
//
// limit sets the maximum number of values each extractor can return; 0 means 1 and
// values above the package maximum are clamped.
func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
	extractors, err := createExtractors(lookups, limit)
	return extractors, err
}
// createExtractors parses a comma-separated lookups string into ValuesExtractors.
//
// Each entry has the form "<source>:<name>", plus an optional ":<cut-prefix>" for header
// sources. Whitespace around <source> is ignored, so "header:Authorization, query:token"
// is accepted. An empty lookups string yields (nil, nil). limit is clamped to
// [1, extractorLimit].
//
// An entry that cannot be split into at least two parts is an error. So is an entry
// whose source is not one of query/param/cookie/form/header; such entries were
// previously dropped silently, turning a config typo into a lookup that never matches.
func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
	if lookups == "" {
		return nil, nil
	}
	if limit == 0 {
		limit = 1
	} else if limit > extractorLimit {
		limit = extractorLimit
	}

	sources := strings.Split(lookups, ",")
	extractors := make([]ValuesExtractor, 0, len(sources))
	for _, source := range sources {
		parts := strings.Split(source, ":")
		if len(parts) < 2 {
			return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
		}
		// Only the source kind is trimmed: names and header prefixes may contain
		// significant whitespace (e.g. the trailing space in "Bearer ").
		switch strings.TrimSpace(parts[0]) {
		case "query":
			extractors = append(extractors, valuesFromQuery(parts[1], limit))
		case "param":
			extractors = append(extractors, valuesFromParam(parts[1], limit))
		case "cookie":
			extractors = append(extractors, valuesFromCookie(parts[1], limit))
		case "form":
			extractors = append(extractors, valuesFromForm(parts[1], limit))
		case "header":
			prefix := ""
			if len(parts) > 2 {
				prefix = parts[2]
			}
			extractors = append(extractors, valuesFromHeader(parts[1], prefix, limit))
		default:
			return nil, fmt.Errorf("extractor source for lookup is unknown: %v", source)
		}
	}
	return extractors, nil
}
// valuesFromHeader returns a function that extracts values from the request header.
//
// valuePrefix is removed from the start of each extracted value. This is useful when the
// header value has a static prefix, like
// `Authorization: <auth-scheme> <authorisation-parameters>`, where the part to remove is
// `<auth-scheme> ` (note the space at the end). For basic authentication
// (`Authorization: Basic <credentials>`) the prefix is `Basic `; for JWT tokens
// (`Authorization: Bearer <token>`) it is `Bearer `.
//
// The prefix is matched case-insensitively. Values that do not start with it, or that
// consist of nothing but the prefix, are skipped. If the prefix is left empty, whole
// values are returned. At most limit values are returned (0 means 1).
func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtractor {
	prefixLen := len(valuePrefix)
	// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
	header = textproto.CanonicalMIMEHeaderKey(header)
	if limit == 0 {
		limit = 1
	}
	return func(c *echo.Context) ([]string, ExtractorSource, error) {
		values := c.Request().Header.Values(header)
		if len(values) == 0 {
			return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
		}

		result := make([]string, 0)
		for _, value := range values {
			if prefixLen > 0 {
				if len(value) <= prefixLen || !strings.EqualFold(value[:prefixLen], valuePrefix) {
					continue
				}
				value = value[prefixLen:]
			}
			result = append(result, value)
			if uint(len(result)) >= limit {
				break
			}
		}

		switch {
		case len(result) > 0:
			return result, ExtractorSourceHeader, nil
		case prefixLen > 0:
			// Header present, but no value carried the expected prefix.
			return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid
		default:
			return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
		}
	}
}
// valuesFromQuery returns a function that extracts values of the named parameter from the query string.
// At most limit values are returned; a limit of 0 is treated as 1.
// The returned slice is a copy, so callers may modify it without touching the url.Values returned by
// c.QueryParams() (consistent with valuesFromForm).
func valuesFromQuery(param string, limit uint) ValuesExtractor {
	if limit == 0 {
		limit = 1
	}
	return func(c *echo.Context) ([]string, ExtractorSource, error) {
		values := c.QueryParams()[param]
		if len(values) == 0 {
			return nil, ExtractorSourceQuery, errQueryExtractorValueMissing
		}
		// compare in uint space so very large limits cannot overflow an int conversion
		if uint(len(values)) > limit {
			values = values[:limit]
		}
		return append([]string(nil), values...), ExtractorSourceQuery, nil
	}
}
// valuesFromParam returns a function that extracts the values of the named path parameter.
// Matches are collected in the order returned by c.PathValues(); at most limit values are
// returned and a limit of 0 is treated as 1.
func valuesFromParam(param string, limit uint) ValuesExtractor {
	if limit == 0 {
		limit = 1
	}
	return func(c *echo.Context) ([]string, ExtractorSource, error) {
		var matched []string
		for _, pv := range c.PathValues() {
			if pv.Name == param {
				matched = append(matched, pv.Value)
				if uint(len(matched)) >= limit {
					break
				}
			}
		}
		if len(matched) == 0 {
			return nil, ExtractorSourcePathParam, errParamExtractorValueMissing
		}
		return matched, ExtractorSourcePathParam, nil
	}
}
// valuesFromCookie returns a function that extracts the values of every cookie with the given name.
// At most limit values are returned; a limit of 0 is treated as 1.
func valuesFromCookie(name string, limit uint) ValuesExtractor {
	if limit == 0 {
		limit = 1
	}
	return func(c *echo.Context) ([]string, ExtractorSource, error) {
		var found []string
		for _, ck := range c.Cookies() {
			if ck.Name != name {
				continue
			}
			found = append(found, ck.Value)
			if uint(len(found)) >= limit {
				break
			}
		}
		if len(found) == 0 {
			// covers both "request has no cookies" and "no cookie with this name"
			return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
		}
		return found, ExtractorSourceCookie, nil
	}
}
// valuesFromForm returns a function that extracts values of the named form field.
// The form is parsed lazily on first use. At most limit values are returned (a limit of 0 is
// treated as 1) and the result is a copy of the request's form values.
func valuesFromForm(name string, limit uint) ValuesExtractor {
	if limit == 0 {
		limit = 1
	}
	return func(c *echo.Context) ([]string, ExtractorSource, error) {
		if c.Request().Form == nil {
			// we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory).
			// The parse error is deliberately ignored: a field that could not be read is
			// reported as errFormExtractorValueMissing below.
			_, _ = c.MultipartForm()
		}
		values := c.Request().Form[name]
		if len(values) == 0 {
			return nil, ExtractorSourceForm, errFormExtractorValueMissing
		}
		// compare in uint space so very large limits cannot overflow an int conversion
		if uint(len(values)) > limit {
			values = values[:limit]
		}
		// copy so the result does not alias the request's Form map values
		return append([]string{}, values...), ExtractorSourceForm, nil
	}
}
================================================
FILE: middleware/extractor_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bytes"
"fmt"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestCreateExtractors checks that each supported lookup source produces a working extractor
// and that malformed lookups are rejected at creation time.
func TestCreateExtractors(t *testing.T) {
	var testCases = []struct {
		name              string
		givenRequest      func() *http.Request
		givenPathValues   echo.PathValues
		whenLookups       string
		whenLimit         uint
		expectValues      []string
		expectSource      ExtractorSource
		expectCreateError string
		expectError       string
	}{
		{
			name: "ok, header",
			givenRequest: func() *http.Request {
				req := httptest.NewRequest(http.MethodGet, "/", nil)
				req.Header.Set(echo.HeaderAuthorization, "Bearer token")
				return req
			},
			whenLookups:  "header:Authorization:Bearer ",
			expectValues: []string{"token"},
			expectSource: ExtractorSourceHeader,
		},
		{
			name: "ok, form",
			givenRequest: func() *http.Request {
				f := make(url.Values)
				f.Set("name", "Jon Snow")
				req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
				req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
				return req
			},
			whenLookups:  "form:name",
			expectValues: []string{"Jon Snow"},
			expectSource: ExtractorSourceForm,
		},
		{
			name: "ok, cookie",
			givenRequest: func() *http.Request {
				req := httptest.NewRequest(http.MethodGet, "/", nil)
				req.Header.Set(echo.HeaderCookie, "_csrf=token")
				return req
			},
			whenLookups:  "cookie:_csrf",
			expectValues: []string{"token"},
			expectSource: ExtractorSourceCookie,
		},
		{
			name: "ok, param",
			givenPathValues: echo.PathValues{
				{Name: "id", Value: "123"},
			},
			whenLookups:  "param:id",
			expectValues: []string{"123"},
			expectSource: ExtractorSourcePathParam,
		},
		{
			name: "ok, query",
			givenRequest: func() *http.Request {
				req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
				return req
			},
			whenLookups:  "query:id",
			expectValues: []string{"999"},
			expectSource: ExtractorSourceQuery,
		},
		{
			name:              "nok, invalid lookup",
			whenLookups:       "query",
			expectCreateError: "extractor source for lookup could not be split into needed parts: query",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.givenRequest != nil {
				req = tc.givenRequest()
			}
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)
			if tc.givenPathValues != nil {
				c.SetPathValues(tc.givenPathValues)
			}
			extractors, err := CreateExtractors(tc.whenLookups, tc.whenLimit)
			if tc.expectCreateError != "" {
				assert.EqualError(t, err, tc.expectCreateError)
				return
			}
			assert.NoError(t, err)
			// Guard against a vacuous pass: an unrecognized source yields no extractor,
			// in which case the loop below would not assert anything.
			if !assert.NotEmpty(t, extractors) {
				return
			}
			for _, extract := range extractors {
				values, source, eErr := extract(c)
				assert.Equal(t, tc.expectValues, values)
				assert.Equal(t, tc.expectSource, source)
				if tc.expectError != "" {
					assert.EqualError(t, eErr, tc.expectError)
					return
				}
				assert.NoError(t, eErr)
			}
		})
	}
}
// TestValuesFromHeader covers prefix stripping (case-insensitive), multiple header values,
// missing/invalid headers and the extractorLimit cut-off.
func TestValuesFromHeader(t *testing.T) {
	exampleRequest := func(req *http.Request) {
		req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
	}
	var testCases = []struct {
		name            string
		givenRequest    func(req *http.Request)
		whenName        string
		whenValuePrefix string
		whenLimit       uint
		expectValues    []string
		expectError     string
	}{
		{
			name:            "ok, single value",
			givenRequest:    exampleRequest,
			whenName:        echo.HeaderAuthorization,
			whenValuePrefix: "basic ",
			expectValues:    []string{"dXNlcjpwYXNzd29yZA=="},
		},
		{
			name:            "ok, single value, case insensitive",
			givenRequest:    exampleRequest,
			whenName:        echo.HeaderAuthorization,
			whenValuePrefix: "Basic ",
			expectValues:    []string{"dXNlcjpwYXNzd29yZA=="},
		},
		{
			name: "ok, multiple value",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
				req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
			},
			whenName:        echo.HeaderAuthorization,
			whenValuePrefix: "basic ",
			whenLimit:       2,
			expectValues:    []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
		},
		{
			name:            "ok, empty prefix",
			givenRequest:    exampleRequest,
			whenName:        echo.HeaderAuthorization,
			whenValuePrefix: "",
			expectValues:    []string{"basic dXNlcjpwYXNzd29yZA=="},
		},
		{
			name: "nok, no matching due different prefix",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
				req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
			},
			whenName:        echo.HeaderAuthorization,
			whenValuePrefix: "Bearer ",
			expectError:     errHeaderExtractorValueInvalid.Error(),
		},
		{
			// previously shared its name with the case above; this one checks a header that is not present at all
			name: "nok, requested header not present",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
				req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
			},
			whenName:        echo.HeaderWWWAuthenticate,
			whenValuePrefix: "",
			expectError:     errHeaderExtractorValueMissing.Error(),
		},
		{
			name:            "nok, no headers",
			givenRequest:    nil,
			whenName:        echo.HeaderAuthorization,
			whenValuePrefix: "basic ",
			expectError:     errHeaderExtractorValueMissing.Error(),
		},
		{
			name: "ok, prefix, cut values over extractorLimit",
			givenRequest: func(req *http.Request) {
				for i := 1; i <= 25; i++ {
					req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i))
				}
			},
			whenName:        echo.HeaderAuthorization,
			whenValuePrefix: "basic ",
			whenLimit:       extractorLimit,
			expectValues: []string{
				"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
				"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
			},
		},
		{
			name: "ok, cut values over extractorLimit",
			givenRequest: func(req *http.Request) {
				for i := 1; i <= 25; i++ {
					req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i))
				}
			},
			whenName:        echo.HeaderAuthorization,
			whenValuePrefix: "",
			whenLimit:       extractorLimit,
			expectValues: []string{
				"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
				"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.givenRequest != nil {
				tc.givenRequest(req)
			}
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)
			extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix, tc.whenLimit)
			values, source, err := extractor(c)
			assert.Equal(t, tc.expectValues, values)
			assert.Equal(t, ExtractorSourceHeader, source)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestValuesFromQuery covers single and repeated query parameters, a missing parameter
// and the extractorLimit cut-off.
func TestValuesFromQuery(t *testing.T) {
	var testCases = []struct {
		name           string
		givenQueryPart string
		whenName       string
		whenLimit      uint
		expectValues   []string
		expectError    string
	}{
		{
			name:           "ok, single value",
			givenQueryPart: "?id=123&name=test",
			whenName:       "id",
			expectValues:   []string{"123"},
		},
		{
			name:           "ok, multiple value",
			givenQueryPart: "?id=123&id=456&name=test",
			whenName:       "id",
			whenLimit:      2,
			expectValues:   []string{"123", "456"},
		},
		{
			name:           "nok, missing value",
			givenQueryPart: "?id=123&name=test",
			whenName:       "nope",
			expectError:    errQueryExtractorValueMissing.Error(),
		},
		{
			name: "ok, cut values over extractorLimit",
			givenQueryPart: "?name=test" +
				"&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" +
				"&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" +
				"&id=21&id=22&id=23&id=24&id=25",
			whenName:  "id",
			whenLimit: extractorLimit,
			expectValues: []string{
				"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
				"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
			c := echo.New().NewContext(req, httptest.NewRecorder())

			values, source, err := valuesFromQuery(tc.whenName, tc.whenLimit)(c)

			assert.Equal(t, tc.expectValues, values)
			assert.Equal(t, ExtractorSourceQuery, source)
			if tc.expectError == "" {
				assert.NoError(t, err)
				return
			}
			assert.EqualError(t, err, tc.expectError)
		})
	}
}
// TestValuesFromParam covers single and repeated path parameters, missing parameters
// and the extractorLimit cut-off.
func TestValuesFromParam(t *testing.T) {
	examplePathValues := echo.PathValues{
		{Name: "id", Value: "123"},
		{Name: "gid", Value: "456"},
		{Name: "gid", Value: "789"},
	}
	// 24 values named "id", more than extractorLimit allows
	manyPathValues := make(echo.PathValues, 0)
	for n := 1; n < 25; n++ {
		manyPathValues = append(manyPathValues, echo.PathValue{Name: "id", Value: fmt.Sprintf("%v", n)})
	}
	var testCases = []struct {
		name            string
		givenPathValues echo.PathValues
		whenName        string
		whenLimit       uint
		expectValues    []string
		expectError     string
	}{
		{
			name:            "ok, single value",
			givenPathValues: examplePathValues,
			whenName:        "id",
			expectValues:    []string{"123"},
		},
		{
			name:            "ok, multiple value",
			givenPathValues: examplePathValues,
			whenName:        "gid",
			whenLimit:       2,
			expectValues:    []string{"456", "789"},
		},
		{
			name:            "nok, no values",
			givenPathValues: nil,
			whenName:        "nope",
			expectValues:    nil,
			expectError:     errParamExtractorValueMissing.Error(),
		},
		{
			name:            "nok, no matching value",
			givenPathValues: examplePathValues,
			whenName:        "nope",
			expectValues:    nil,
			expectError:     errParamExtractorValueMissing.Error(),
		},
		{
			name:            "ok, cut values over extractorLimit",
			givenPathValues: manyPathValues,
			whenName:        "id",
			whenLimit:       extractorLimit,
			expectValues: []string{
				"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
				"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			c := echo.New().NewContext(req, httptest.NewRecorder())
			if tc.givenPathValues != nil {
				c.SetPathValues(tc.givenPathValues)
			}

			values, source, err := valuesFromParam(tc.whenName, tc.whenLimit)(c)

			assert.Equal(t, tc.expectValues, values)
			assert.Equal(t, ExtractorSourcePathParam, source)
			if tc.expectError == "" {
				assert.NoError(t, err)
				return
			}
			assert.EqualError(t, err, tc.expectError)
		})
	}
}
// TestValuesFromCookie covers single and repeated cookies, missing cookies and the
// extractorLimit cut-off.
func TestValuesFromCookie(t *testing.T) {
	exampleRequest := func(req *http.Request) {
		req.Header.Set(echo.HeaderCookie, "_csrf=token")
	}
	var testCases = []struct {
		name         string
		givenRequest func(req *http.Request)
		whenName     string
		whenLimit    uint
		expectValues []string
		expectError  string
	}{
		{
			name:         "ok, single value",
			givenRequest: exampleRequest,
			whenName:     "_csrf",
			expectValues: []string{"token"},
		},
		{
			name: "ok, multiple value",
			givenRequest: func(req *http.Request) {
				req.Header.Add(echo.HeaderCookie, "_csrf=token")
				req.Header.Add(echo.HeaderCookie, "_csrf=token2")
			},
			whenName:     "_csrf",
			whenLimit:    2,
			expectValues: []string{"token", "token2"},
		},
		{
			name:         "nok, no matching cookie",
			givenRequest: exampleRequest,
			whenName:     "xxx",
			expectValues: nil,
			expectError:  errCookieExtractorValueMissing.Error(),
		},
		{
			name:         "nok, no cookies at all",
			givenRequest: nil,
			whenName:     "xxx",
			expectValues: nil,
			expectError:  errCookieExtractorValueMissing.Error(),
		},
		{
			name: "ok, cut values over extractorLimit",
			givenRequest: func(req *http.Request) {
				for n := 1; n < 25; n++ {
					req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", n))
				}
			},
			whenName:  "_csrf",
			whenLimit: extractorLimit,
			expectValues: []string{
				"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
				"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.givenRequest != nil {
				tc.givenRequest(req)
			}
			c := echo.New().NewContext(req, httptest.NewRecorder())

			values, source, err := valuesFromCookie(tc.whenName, tc.whenLimit)(c)

			assert.Equal(t, tc.expectValues, values)
			assert.Equal(t, ExtractorSourceCookie, source)
			if tc.expectError == "" {
				assert.NoError(t, err)
				return
			}
			assert.EqualError(t, err, tc.expectError)
		})
	}
}
// TestValuesFromForm covers url-encoded POST forms, multipart forms, GET query forms,
// a missing field and the extractorLimit cut-off.
func TestValuesFromForm(t *testing.T) {
	examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
		f := make(url.Values)
		f.Set("name", "Jon Snow")
		f.Set("emails[]", "jon@labstack.com")
		if mod != nil {
			mod(&f)
		}
		req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
		req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
		return req
	}
	exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
		f := make(url.Values)
		f.Set("name", "Jon Snow")
		f.Set("emails[]", "jon@labstack.com")
		if mod != nil {
			mod(&f)
		}
		req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
		return req
	}
	exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request {
		var b bytes.Buffer
		w := multipart.NewWriter(&b)
		w.WriteField("name", "Jon Snow")
		w.WriteField("emails[]", "jon@labstack.com")
		if mod != nil {
			mod(w)
		}
		fw, _ := w.CreateFormFile("upload", "my.file")
		fw.Write([]byte("hi\n"))
		w.Close()
		req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String()))
		req.Header.Add(echo.HeaderContentType, w.FormDataContentType())
		return req
	}
	var testCases = []struct {
		name         string
		givenRequest *http.Request
		whenName     string
		whenLimit    uint
		expectValues []string
		expectError  string
	}{
		{
			name:         "ok, POST form, single value",
			givenRequest: examplePostFormRequest(nil),
			whenName:     "emails[]",
			expectValues: []string{"jon@labstack.com"},
		},
		{
			name: "ok, POST form, multiple value",
			givenRequest: examplePostFormRequest(func(v *url.Values) {
				v.Add("emails[]", "snow@labstack.com")
			}),
			whenName:     "emails[]",
			whenLimit:    2,
			expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
		},
		{
			name: "ok, POST multipart/form, multiple value",
			givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) {
				w.WriteField("emails[]", "snow@labstack.com")
			}),
			whenName:     "emails[]",
			whenLimit:    2,
			expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
		},
		{
			name:         "ok, GET form, single value",
			givenRequest: exampleGetFormRequest(nil),
			whenName:     "emails[]",
			expectValues: []string{"jon@labstack.com"},
		},
		{
			// must use the GET helper; using the POST helper here left GET multi-value untested
			name: "ok, GET form, multiple value",
			givenRequest: exampleGetFormRequest(func(v *url.Values) {
				v.Add("emails[]", "snow@labstack.com")
			}),
			whenName:     "emails[]",
			whenLimit:    2,
			expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
		},
		{
			name:         "nok, POST form, value missing",
			givenRequest: examplePostFormRequest(nil),
			whenName:     "nope",
			expectError:  errFormExtractorValueMissing.Error(),
		},
		{
			name: "ok, cut values over extractorLimit",
			givenRequest: examplePostFormRequest(func(v *url.Values) {
				for i := 1; i < 25; i++ {
					v.Add("id[]", fmt.Sprintf("%v", i))
				}
			}),
			whenName:  "id[]",
			whenLimit: extractorLimit,
			expectValues: []string{
				"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
				"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			req := tc.givenRequest
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)
			extractor := valuesFromForm(tc.whenName, tc.whenLimit)
			values, source, err := extractor(c)
			assert.Equal(t, tc.expectValues, values)
			assert.Equal(t, ExtractorSourceForm, source)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
================================================
FILE: middleware/key_auth.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"cmp"
"errors"
"fmt"
"net/http"
"github.com/labstack/echo/v5"
)
// KeyAuthConfig defines the config for KeyAuth middleware.
//
// SECURITY: The Validator function is responsible for securely comparing API keys.
// See KeyAuthValidator documentation for guidance on preventing timing attacks.
type KeyAuthConfig struct {
// Skipper defines a function to skip middleware.
// Optional. When nil, DefaultKeyAuthConfig.Skipper is used.
Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract key from the request.
// Optional. When empty, DefaultKeyAuthConfig.KeyLookup ("header:Authorization:Bearer ") is used.
// Possible values:
//   - "header:<name>" or "header:<name>:<cut-prefix>"
//     `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
//     value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
//     want to cut is `<auth-scheme> ` note the space at the end.
//     In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
//     The prefix is compared case-insensitively.
//   - "query:<name>"
//   - "param:<name>"
//   - "form:<name>"
//   - "cookie:<name>"
// Multiple sources example:
//   - "header:Authorization,header:X-Api-Key"
// Sources are tried in the given order. Entries with an unrecognized "<source>" are ignored; if no entry is
// recognized ToMiddleware returns an error.
KeyLookup string
// AllowedCheckLimit sets how many values per KeyLookup source are passed to Validator. This is
// useful in environments like corporate test environments with application proxies restricting
// access to environment with their own auth scheme.
// Optional. 0 is treated as 1; values above extractorLimit are capped to extractorLimit.
AllowedCheckLimit uint
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator
// ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
// function. It receives the last error produced by Validator (ErrInvalidKey when Validator returned false without an
// error) or, when Validator was never invoked, the last extraction (missing/invalid value) error.
// It may be used to define a custom error.
//
// Note: when ErrorHandler returns nil, the middleware continues the handler chain only if ContinueOnIgnoredError
// is true; otherwise the middleware returns nil without calling the next handler.
// This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users.
// In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain.
ErrorHandler KeyAuthErrorHandler
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
// ignore the error (by returning `nil`).
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
// In that case you can use ErrorHandler to set a default public key auth value in the request context
// and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
ContinueOnIgnoredError bool
}
// KeyAuthValidator defines a function to validate KeyAuth credentials.
//
// The middleware calls it once per extracted key, together with the source the key came from.
// Returning (true, nil) accepts the request. Returning (false, nil) rejects that key (recorded as
// ErrInvalidKey), and a non-nil error also rejects it (the error is recorded as-is). In both
// rejection cases the middleware continues with the next extracted key.
//
// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate
// valid API keys, validator implementations MUST use constant-time comparison.
// Use crypto/subtle.ConstantTimeCompare instead of standard string equality (==)
// or switch statements.
//
// Example of SECURE implementation:
//
//	import "crypto/subtle"
//
//	validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
//		// Fetch valid keys from database/config
//		validKeys := []string{"key1", "key2", "key3"}
//
//		for _, validKey := range validKeys {
//			// Use constant-time comparison to prevent timing attacks
//			if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
//				return true, nil
//			}
//		}
//		return false, nil
//	}
//
// Example of INSECURE implementation (DO NOT USE):
//
//	// VULNERABLE TO TIMING ATTACKS - DO NOT USE
//	validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
//		switch key { // Timing leak!
//		case "valid-key":
//			return true, nil
//		default:
//			return false, nil
//		}
//	}
type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error)
// KeyAuthErrorHandler defines a function which is executed when no extracted key passed validation,
// i.e. for a missing key as well as for an invalid key. err is the last validator error or, when the
// validator was never invoked, the last extraction error. Returning nil ignores the error; whether the
// request then continues to the next handler is controlled by KeyAuthConfig.ContinueOnIgnoredError.
type KeyAuthErrorHandler func(c *echo.Context, err error) error
var (
	// ErrKeyMissing denotes an error raised when key value could not be extracted from request.
	// Like ErrInvalidKey it carries status 401.
	ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
	// ErrInvalidKey denotes an error raised when key value is rejected by the validator.
	ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
)

// DefaultKeyAuthConfig is the default KeyAuth middleware config: keys are read from the
// Authorization header with the `Bearer ` prefix removed. It has no Validator; one must be supplied.
var DefaultKeyAuthConfig = KeyAuthConfig{
	Skipper:   DefaultSkipper,
	KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
}
// KeyAuth returns a KeyAuth middleware that uses DefaultKeyAuthConfig with fn as the validator.
//
// For the first valid key it calls the next handler.
// For an invalid key it returns a "401 - Unauthorized" error.
// For a missing key it also returns a "401 - Unauthorized" error (wrapping ErrKeyMissing's message).
//
// It panics when the resulting configuration is invalid (e.g. fn is nil).
func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
	config := DefaultKeyAuthConfig
	config.Validator = fn
	return KeyAuthWithConfig(config)
}
// KeyAuthWithConfig returns a KeyAuth middleware or panics if configuration is invalid.
// Use KeyAuthConfig.ToMiddleware to receive configuration problems as an error instead.
//
// For the first valid key it calls the next handler.
// For an invalid key it returns a "401 - Unauthorized" error.
// For a missing key it also returns a "401 - Unauthorized" error (ErrKeyMissing).
// When config.ErrorHandler is set, its result replaces these errors.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
	return toMiddlewareOrPanic(config)
}
// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration.
//
// Empty Skipper/KeyLookup fields fall back to DefaultKeyAuthConfig. Validator is required, and
// KeyLookup must produce at least one extractor.
//
// At request time every extracted key is validated in lookup order; the first accepted key lets the
// request through. Otherwise the error reported is the last validator error or, if the validator was
// never invoked, the last extraction error. That error goes to ErrorHandler when set. Without an
// ErrorHandler it is wrapped into a 401 (ErrUnauthorized for validator failures, ErrKeyMissing otherwise).
func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	skipper := config.Skipper
	if skipper == nil {
		skipper = DefaultKeyAuthConfig.Skipper
	}
	lookup := cmp.Or(config.KeyLookup, DefaultKeyAuthConfig.KeyLookup)

	validator := config.Validator
	if validator == nil {
		return nil, errors.New("echo key-auth middleware requires a validator function")
	}

	extractors, createErr := createExtractors(lookup, cmp.Or(config.AllowedCheckLimit, 1))
	if createErr != nil {
		return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", createErr)
	}
	if len(extractors) == 0 {
		return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
	}

	onError := config.ErrorHandler
	continueOnIgnored := config.ContinueOnIgnoredError

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if skipper(c) {
				return next(c)
			}

			var extractErr, validateErr error
			for _, extract := range extractors {
				keys, source, exErr := extract(c)
				if exErr != nil {
					extractErr = exErr
					continue
				}
				for _, key := range keys {
					valid, vErr := validator(c, key, source)
					switch {
					case vErr != nil:
						validateErr = vErr
					case !valid:
						validateErr = ErrInvalidKey
					default:
						return next(c)
					}
				}
			}

			// a validator verdict is more informative than a failed extraction
			cause := validateErr
			if cause == nil {
				cause = extractErr
			}

			if onError != nil {
				handled := onError(c, cause)
				if handled == nil && continueOnIgnored {
					return next(c)
				}
				return handled
			}
			if validateErr != nil {
				return echo.ErrUnauthorized.Wrap(cause)
			}
			return ErrKeyMissing.Wrap(cause)
		}
	}, nil
}
================================================
FILE: middleware/key_auth_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"crypto/subtle"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// testKeyValidator accepts only "valid-key" and returns an error for "error-key"
// so tests can exercise the validator-error path.
func testKeyValidator(c *echo.Context, key string, source ExtractorSource) (bool, error) {
	// Use constant-time comparison to prevent timing attacks
	valid := subtle.ConstantTimeCompare([]byte(key), []byte("valid-key")) == 1
	if !valid && key == "error-key" { // Error path doesn't need constant-time
		return false, errors.New("some user defined error")
	}
	return valid, nil
}
func TestKeyAuth(t *testing.T) {
handlerCalled := false
handler := func(c *echo.Context) error {
handlerCalled = true
return c.String(http.StatusOK, "test")
}
middlewareChain := KeyAuth(testKeyValidator)(handler)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := middlewareChain(c)
assert.NoError(t, err)
assert.True(t, handlerCalled)
}
// TestKeyAuthWithConfig exercises every lookup source, skipper, custom error handlers and
// custom validators of the KeyAuth middleware.
func TestKeyAuthWithConfig(t *testing.T) {
	var testCases = []struct {
		name                string
		givenRequestFunc    func() *http.Request
		givenRequest        func(req *http.Request)
		whenConfig          func(conf *KeyAuthConfig)
		expectHandlerCalled bool
		expectError         string
	}{
		{
			name: "ok, defaults, key from header",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
			},
			expectHandlerCalled: true,
		},
		{
			name: "ok, custom skipper",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.Skipper = func(context *echo.Context) bool {
					return true
				}
			},
			expectHandlerCalled: true,
		},
		{
			name: "nok, defaults, invalid key from header, Authorization: Bearer",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=Unauthorized, err=code=401, message=invalid key",
		},
		{
			name: "nok, defaults, invalid scheme in header",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=invalid value in request header",
		},
		{
			name:                "nok, defaults, missing header",
			givenRequest:        func(req *http.Request) {},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in request header",
		},
		{
			name: "ok, custom key lookup, header",
			givenRequest: func(req *http.Request) {
				req.Header.Set("API-Key", "valid-key")
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "header:API-Key"
			},
			expectHandlerCalled: true,
		},
		{
			name:         "nok, custom key lookup, missing header",
			givenRequest: func(req *http.Request) {},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "header:API-Key"
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in request header",
		},
		{
			name: "ok, custom key lookup, query",
			givenRequest: func(req *http.Request) {
				q := req.URL.Query()
				q.Add("key", "valid-key")
				req.URL.RawQuery = q.Encode()
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "query:key"
			},
			expectHandlerCalled: true,
		},
		{
			name: "nok, custom key lookup, missing query param",
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "query:key"
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in the query string",
		},
		{
			name: "ok, custom key lookup, form",
			givenRequestFunc: func() *http.Request {
				req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key"))
				req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
				return req
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "form:key"
			},
			expectHandlerCalled: true,
		},
		{
			name: "nok, custom key lookup, missing key in form",
			givenRequestFunc: func() *http.Request {
				req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key"))
				req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
				return req
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "form:key"
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in the form",
		},
		{
			// only the cookie carries the key, so a regression that reads another source would fail here
			name: "ok, custom key lookup, cookie",
			givenRequest: func(req *http.Request) {
				req.AddCookie(&http.Cookie{
					Name:  "key",
					Value: "valid-key",
				})
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "cookie:key"
			},
			expectHandlerCalled: true,
		},
		{
			name: "nok, custom key lookup, missing cookie param",
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "cookie:key"
			},
			expectHandlerCalled: false,
			expectError:         "code=401, message=missing key, err=missing value in cookies",
		},
		{
			name: "nok, custom errorHandler, error from extractor",
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "header:token"
				conf.ErrorHandler = func(c *echo.Context, err error) error {
					return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
				}
			},
			expectHandlerCalled: false,
			expectError:         "code=418, message=custom, err=missing value in request header",
		},
		{
			name: "nok, custom errorHandler, error from validator",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.ErrorHandler = func(c *echo.Context, err error) error {
					return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
				}
			},
			expectHandlerCalled: false,
			expectError:         "code=418, message=custom, err=some user defined error",
		},
		{
			name: "nok, defaults, error from validator",
			givenRequest: func(req *http.Request) {
				req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
			},
			whenConfig:          func(conf *KeyAuthConfig) {},
			expectHandlerCalled: false,
			expectError:         "code=401, message=Unauthorized, err=some user defined error",
		},
		{
			name: "ok, custom validator checks source",
			givenRequest: func(req *http.Request) {
				q := req.URL.Query()
				q.Add("key", "valid-key")
				req.URL.RawQuery = q.Encode()
			},
			whenConfig: func(conf *KeyAuthConfig) {
				conf.KeyLookup = "query:key"
				conf.Validator = func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
					if source == ExtractorSourceQuery {
						return true, nil
					}
					return false, errors.New("invalid source")
				}
			},
			expectHandlerCalled: true,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			handlerCalled := false
			handler := func(c *echo.Context) error {
				handlerCalled = true
				return c.String(http.StatusOK, "test")
			}
			config := KeyAuthConfig{
				Validator: testKeyValidator,
			}
			if tc.whenConfig != nil {
				tc.whenConfig(&config)
			}
			middlewareChain := KeyAuthWithConfig(config)(handler)
			e := echo.New()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.givenRequestFunc != nil {
				req = tc.givenRequestFunc()
			}
			if tc.givenRequest != nil {
				tc.givenRequest(req)
			}
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)
			err := middlewareChain(c)
			assert.Equal(t, tc.expectHandlerCalled, handlerCalled)
			if tc.expectError != "" {
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// TestKeyAuthWithConfig_errors checks that KeyAuthConfig.ToMiddleware returns a nil middleware
// together with a descriptive error for invalid configurations, and a middleware without error
// for a valid one. Case names follow the "ok"/"nok" convention used by the other key-auth tables.
func TestKeyAuthWithConfig_errors(t *testing.T) {
	var testCases = []struct {
		name        string
		whenConfig  KeyAuthConfig
		expectError string
	}{
		{
			name: "ok, no error",
			whenConfig: KeyAuthConfig{
				Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
					return false, nil
				},
			},
		},
		{
			name: "nok, missing validator func",
			whenConfig: KeyAuthConfig{
				Validator: nil,
			},
			expectError: "echo key-auth middleware requires a validator function",
		},
		{
			name: "nok, extractor source can not be split",
			whenConfig: KeyAuthConfig{
				KeyLookup: "nope",
				Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
					return false, nil
				},
			},
			expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
		},
		{
			name: "nok, no extractors",
			whenConfig: KeyAuthConfig{
				KeyLookup: "nope:nope",
				Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
					return false, nil
				},
			},
			expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mw, err := tc.whenConfig.ToMiddleware()
			if tc.expectError != "" {
				assert.Nil(t, mw)
				assert.EqualError(t, err, tc.expectError)
			} else {
				assert.NotNil(t, mw)
				assert.NoError(t, err)
			}
		})
	}
}
// TestMustKeyAuthWithConfig_panic checks that KeyAuthWithConfig panics with the configuration
// error produced by ToMiddleware when no Validator is configured (the same error the
// "missing validator func" case of TestKeyAuthWithConfig_errors expects).
func TestMustKeyAuthWithConfig_panic(t *testing.T) {
	assert.PanicsWithError(t, "echo key-auth middleware requires a validator function", func() {
		KeyAuthWithConfig(KeyAuthConfig{})
	})
}
// TestKeyAuth_errorHandlerSwallowsError verifies that when ErrorHandler returns nil and
// ContinueOnIgnoredError is set, a request without credentials still reaches the next handler
// and sees the values the ErrorHandler stored in the context.
func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
	var (
		called bool
		auth   string
	)
	next := func(c *echo.Context) error {
		called = true
		auth = c.Get("auth").(string)
		return c.String(http.StatusOK, "test")
	}
	cfg := KeyAuthConfig{
		Validator: testKeyValidator,
		ErrorHandler: func(c *echo.Context, err error) error {
			// a real handler could inspect err to decide whether swallowing it is acceptable
			c.Set("auth", "public")
			return nil
		},
		ContinueOnIgnoredError: true,
	}
	chain := KeyAuthWithConfig(cfg)(next)

	// the request deliberately carries no Authorization header
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := echo.New().NewContext(req, rec)

	assert.NoError(t, chain(c))
	assert.True(t, called)
	assert.Equal(t, "public", auth)
}
================================================
FILE: middleware/method_override.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"net/http"
"github.com/labstack/echo/v5"
)
// MethodOverrideConfig defines the config for MethodOverride middleware.
type MethodOverrideConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. Default value DefaultMethodOverrideConfig.Skipper (DefaultSkipper).
	Skipper Skipper
	// Getter is a function that gets overridden method from the request.
	// An empty string returned by Getter means "do not override".
	// Optional. Default value MethodFromHeader(echo.HeaderXHTTPMethodOverride).
	Getter MethodOverrideGetter
}
// MethodOverrideGetter is a function that gets overridden method from the request.
// Returning an empty string leaves the request method unchanged.
type MethodOverrideGetter func(c *echo.Context) string
// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
// It never skips and reads the overriding method from the request header named by
// echo.HeaderXHTTPMethodOverride.
var DefaultMethodOverrideConfig = MethodOverrideConfig{
	Skipper: DefaultSkipper,
	Getter:  MethodFromHeader(echo.HeaderXHTTPMethodOverride),
}
// MethodOverride returns a MethodOverride middleware configured with DefaultMethodOverrideConfig.
// The middleware looks for an overridden method on the request and, when one is present,
// uses it in place of the original method.
//
// For security reasons, only `POST` method can be overridden.
func MethodOverride() echo.MiddlewareFunc {
	cfg := DefaultMethodOverrideConfig
	return MethodOverrideWithConfig(cfg)
}

// MethodOverrideWithConfig returns a Method Override middleware built from config.
// It panics when config is invalid (see MethodOverrideConfig.ToMiddleware).
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration.
//
// A nil Skipper or Getter falls back to the corresponding DefaultMethodOverrideConfig field.
// Only POST requests are considered; when the getter yields a non-empty value, it replaces
// the request method before the next handler runs.
func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	skip := config.Skipper
	if skip == nil {
		skip = DefaultMethodOverrideConfig.Skipper
	}
	getMethod := config.Getter
	if getMethod == nil {
		getMethod = DefaultMethodOverrideConfig.Getter
	}

	mw := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if skip(c) {
				return next(c)
			}
			if req := c.Request(); req.Method == http.MethodPost {
				if override := getMethod(c); override != "" {
					req.Method = override
				}
			}
			return next(c)
		}
	}
	return mw, nil
}
// MethodFromHeader is a `MethodOverrideGetter` that reads the overridden method from the
// request header called header (empty string when the header is absent).
func MethodFromHeader(header string) MethodOverrideGetter {
	return func(c *echo.Context) string {
		h := c.Request().Header
		return h.Get(header)
	}
}

// MethodFromForm is a `MethodOverrideGetter` that reads the overridden method from the
// form parameter called param.
func MethodFromForm(param string) MethodOverrideGetter {
	getter := func(c *echo.Context) string { return c.FormValue(param) }
	return getter
}

// MethodFromQuery is a `MethodOverrideGetter` that reads the overridden method from the
// query parameter called param.
func MethodFromQuery(param string) MethodOverrideGetter {
	getter := func(c *echo.Context) string { return c.QueryParam(param) }
	return getter
}
================================================
FILE: middleware/method_override_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestMethodOverride checks that the default middleware overrides a POST with the method
// given in the X-HTTP-Method-Override header.
func TestMethodOverride(t *testing.T) {
	e := echo.New()
	mw := MethodOverride()
	okHandler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}

	req := httptest.NewRequest(http.MethodPost, "/", nil)
	req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
	c := e.NewContext(req, httptest.NewRecorder())

	assert.NoError(t, mw(okHandler)(c))
	assert.Equal(t, http.MethodDelete, req.Method)
}
// TestMethodOverride_formParam checks overriding via the `_method` form field.
func TestMethodOverride_formParam(t *testing.T) {
	e := echo.New()
	okHandler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	mw, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware()
	assert.NoError(t, err)

	form := bytes.NewBufferString("_method=" + http.MethodDelete)
	req := httptest.NewRequest(http.MethodPost, "/", form)
	req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
	c := e.NewContext(req, httptest.NewRecorder())

	assert.NoError(t, mw(okHandler)(c))
	assert.Equal(t, http.MethodDelete, req.Method)
}
// TestMethodOverride_queryParam checks overriding via the `_method` query parameter.
func TestMethodOverride_queryParam(t *testing.T) {
	e := echo.New()
	okHandler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	mw, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware()
	assert.NoError(t, err)

	target := "/?_method=" + http.MethodDelete
	req := httptest.NewRequest(http.MethodPost, target, nil)
	c := e.NewContext(req, httptest.NewRecorder())

	assert.NoError(t, mw(okHandler)(c))
	assert.Equal(t, http.MethodDelete, req.Method)
}
// TestMethodOverride_ignoreGet checks that non-POST requests keep their method even when
// an override header is present.
func TestMethodOverride_ignoreGet(t *testing.T) {
	e := echo.New()
	mw := MethodOverride()
	okHandler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
	c := e.NewContext(req, httptest.NewRecorder())

	assert.NoError(t, mw(okHandler)(c))
	assert.Equal(t, http.MethodGet, req.Method)
}
================================================
FILE: middleware/middleware.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"net/http"
"regexp"
"strconv"
"strings"
"github.com/labstack/echo/v5"
)
// Skipper defines a function to skip middleware. Returning true skips processing the middleware.
// See DefaultSkipper for the never-skip implementation used as the default by middleware configs.
type Skipper func(c *echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware.
// NOTE(review): it is not referenced in this file; presumably individual middleware configs
// accept and invoke it — confirm against their implementations.
type BeforeFunc func(c *echo.Context)
// captureTokens matches input against pattern and, when it matches, returns a replacer that
// substitutes the placeholders "$1", "$2", … with the text captured by the corresponding group
// of the first (leftmost) match. It returns nil when pattern does not match input.
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
	match := pattern.FindStringSubmatch(input)
	if match == nil {
		return nil
	}
	values := match[1:]
	replace := make([]string, 0, 2*len(values))
	// strings.Replacer resolves keys that start at the same position by argument order, so
	// listing "$1" before "$10" would make "$10" expand to value1+"0". Emitting the highest
	// index first lets multi-digit placeholders win over their single-digit prefixes.
	for i := len(values) - 1; i >= 0; i-- {
		replace = append(replace, "$"+strconv.Itoa(i+1), values[i])
	}
	return strings.NewReplacer(replace...)
}
// rewriteRulesRegex compiles simple wildcard rewrite rules into regular expressions.
//
// Each key is matched literally, except that every "*" becomes the lazy capture group "(.*?)"
// and a leading "^" anchors the pattern at the start. Every pattern is anchored at the end
// with "$". Values are stored unchanged and may reference the captures as $1, $2, ….
// It panics (via regexp.MustCompile) only if the generated expression is invalid.
func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
	rulesRegex := make(map[*regexp.Regexp]string, len(rewrite))
	for k, v := range rewrite {
		k = regexp.QuoteMeta(k)
		k = strings.ReplaceAll(k, `\*`, "(.*?)")
		// Only a leading "^" is an anchor. A "^" elsewhere in the key must stay escaped, otherwise
		// the resulting pattern would require a start-of-text in mid-path and never match.
		if strings.HasPrefix(k, `\^`) {
			k = "^" + k[len(`\^`):]
		}
		rulesRegex[regexp.MustCompile(k+"$")] = v
	}
	return rulesRegex
}
// rewriteURL applies at most one rewrite rule to req.
//
// The rules are matched against req.RequestURI (path plus any query string), with a leading
// "scheme://host" removed when the request was sent in absolute form. The first rule that
// matches, in map iteration order, has its captures substituted into its value; the result is
// resolved against req.URL with (*url.URL).Parse and stored back into req.URL. A parse error is
// returned; no rules or no match leaves req untouched.
func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error {
	if len(rewriteRegex) == 0 {
		return nil
	}

	// RequestURI may be "scheme://host/path" or just "/path" depending on how the request was
	// sent; rules are written against the path, so strip the scheme/host part when present.
	uri := req.RequestURI
	if uri != "" && !strings.HasPrefix(uri, "/") {
		var schemeHost string
		if req.URL.Scheme != "" {
			schemeHost = req.URL.Scheme + "://"
		}
		schemeHost += req.URL.Host // host or host:port, may be empty
		uri = strings.TrimPrefix(uri, schemeHost)
	}

	for pattern, target := range rewriteRegex {
		replacer := captureTokens(pattern, uri)
		if replacer == nil {
			continue
		}
		rewritten, err := req.URL.Parse(replacer.Replace(target))
		if err != nil {
			return err
		}
		req.URL = rewritten
		return nil // only the first matching rule is applied
	}
	return nil
}
// DefaultSkipper never skips: it returns false for every request, so the middleware
// is always processed.
func DefaultSkipper(_ *echo.Context) bool {
	const skip = false
	return skip
}
// toMiddlewareOrPanic builds the middleware described by config and panics with the error
// returned by config.ToMiddleware when the configuration is invalid. It backs the
// XxxWithConfig constructors that have no error return.
func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc {
	mw, err := config.ToMiddleware()
	if err == nil {
		return mw
	}
	panic(err)
}
================================================
FILE: middleware/middleware_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bufio"
"errors"
"github.com/stretchr/testify/assert"
"net"
"net/http"
"net/http/httptest"
"regexp"
"testing"
)
// TestRewriteURL checks how rewriteURL handles absolute-form and origin-form request URIs,
// percent-encoded paths (Path vs RawPath) and query strings coming from the rule value.
func TestRewriteURL(t *testing.T) {
	rules := map[*regexp.Regexp]string{
		regexp.MustCompile("^/old$"):                      "/new",
		regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2",
		regexp.MustCompile("^/static$"):                   "/static/path?role=AUTHOR&limit=1000",
	}

	type rewriteCase struct {
		whenURL       string
		expectPath    string
		expectRawPath string
		expectQuery   string
		expectErr     string
	}
	cases := []rewriteCase{
		{
			whenURL:    "http://localhost:8080/old",
			expectPath: "/new",
		},
		{ // encoded `ol%64` (decoded `old`) should not be rewritten to `/new`
			whenURL:       "/ol%64", // `%64` is decoded `d`
			expectPath:    "/old",
			expectRawPath: "/ol%64",
		},
		{
			whenURL:     "http://localhost:8080/users/+_+/orders/___++++?test=1",
			expectPath:  "/user/+_+/order/___++++",
			expectQuery: "test=1",
		},
		{
			whenURL:    "http://localhost:8080/users/%20a/orders/%20aa",
			expectPath: "/user/ a/order/ aa",
		},
		{
			whenURL:       "http://localhost:8080/%47%6f%2f?test=1",
			expectPath:    "/Go/",
			expectRawPath: "/%47%6f%2f",
			expectQuery:   "test=1",
		},
		{
			whenURL:       "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
			expectPath:    "/user/jill/order/T/cO4lW/t/Vp/",
			expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
		},
		{ // do nothing, replace nothing
			whenURL:       "http://localhost:8080/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
			expectPath:    "/user/jill/order/T/cO4lW/t/Vp/",
			expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
		},
		{
			whenURL:     "http://localhost:8080/static",
			expectPath:  "/static/path",
			expectQuery: "role=AUTHOR&limit=1000",
		},
		{
			whenURL:     "/static",
			expectPath:  "/static/path",
			expectQuery: "role=AUTHOR&limit=1000",
		},
	}

	for _, tt := range cases {
		t.Run(tt.whenURL, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, tt.whenURL, nil)
			err := rewriteURL(rules, req)
			if tt.expectErr == "" {
				assert.NoError(t, err)
			} else {
				assert.EqualError(t, err, tt.expectErr)
			}
			// URL.Path is stored decoded: /%47%6f%2f becomes /Go/.
			assert.Equal(t, tt.expectPath, req.URL.Path)
			// URL.RawPath is only set when the default encoding of Path differs from the original.
			assert.Equal(t, tt.expectRawPath, req.URL.RawPath)
			assert.Equal(t, tt.expectQuery, req.URL.RawQuery)
		})
	}
}
// testResponseWriterNoFlushHijack is a minimal http.ResponseWriter with no Flush or Hijack
// methods; every method is a no-op.
type testResponseWriterNoFlushHijack struct{}

func (*testResponseWriterNoFlushHijack) WriteHeader(int) {}

func (*testResponseWriterNoFlushHijack) Write([]byte) (int, error) { return 0, nil }

func (*testResponseWriterNoFlushHijack) Header() http.Header { return nil }

// testResponseWriterUnwrapper is a no-op http.ResponseWriter whose Unwrap returns rw and
// records the number of Unwrap calls in unwrapCalled.
type testResponseWriterUnwrapper struct {
	unwrapCalled int
	rw           http.ResponseWriter
}

func (*testResponseWriterUnwrapper) WriteHeader(int) {}

func (*testResponseWriterUnwrapper) Write([]byte) (int, error) { return 0, nil }

func (*testResponseWriterUnwrapper) Header() http.Header { return nil }

func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
	w.unwrapCalled++
	return w.rw
}

// testResponseWriterUnwrapperHijack extends testResponseWriterUnwrapper with a Hijack method
// that always fails with the error "can hijack".
type testResponseWriterUnwrapperHijack struct {
	testResponseWriterUnwrapper
}

func (*testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return nil, nil, errors.New("can hijack")
}
================================================
FILE: middleware/proxy.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/labstack/echo/v5"
)
// TODO: Handle TLS proxy
// NOTE(review): proxyRaw already dials TLS when Transport is an *http.Transport with a
// TLSClientConfig; confirm whether this TODO is still relevant.

// ProxyConfig defines the config for Proxy middleware.
type ProxyConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. Default value DefaultProxyConfig.Skipper (DefaultSkipper).
	Skipper Skipper
	// Balancer defines a load balancing technique.
	// Required.
	Balancer ProxyBalancer
	// RetryCount defines the number of times a failed proxied request should be retried
	// using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
	RetryCount int
	// RetryFilter defines a function used to determine if a failed request to a
	// ProxyTarget should be retried. The RetryFilter will only be called when the number
	// of previous retries is less than RetryCount. If the function returns true, the
	// request will be retried. The provided error indicates the reason for the request
	// failure. When the ProxyTarget is unavailable, the error will be an instance of
	// echo.HTTPError with a code of http.StatusBadGateway. In all other cases, the error
	// will indicate an internal error in the Proxy middleware. When a RetryFilter is not
	// specified, all requests that fail with http.StatusBadGateway will be retried. A custom
	// RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
	// only called when the request to the target fails, or an internal error in the Proxy
	// middleware has occurred. Successful requests that return a non-200 response code cannot
	// be retried.
	RetryFilter func(c *echo.Context, e error) bool
	// ErrorHandler defines a function which can be used to return custom errors from
	// the Proxy middleware. ErrorHandler is only invoked when there has been
	// either an internal error in the Proxy middleware or the ProxyTarget is
	// unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
	// when a ProxyTarget returns a non-200 response. In these cases, the response
	// is already written so errors cannot be modified. ErrorHandler is only
	// invoked after all retry attempts have been exhausted.
	// Optional. By default the error is returned unchanged.
	ErrorHandler func(c *echo.Context, err error) error
	// Rewrite defines URL path rewrite rules. The values captured in asterisk can be
	// retrieved by index e.g. $1, $2 and so on. A leading "^" anchors a rule at the start
	// of the path; every rule is anchored at the end.
	// Examples:
	// "/old": "/new",
	// "/api/*": "/$1",
	// "/js/*": "/public/javascripts/$1",
	// "/users/*/orders/*": "/user/$1/order/$2",
	Rewrite map[string]string
	// RegexRewrite defines rewrite rules using regexp.Regexp with captures.
	// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
	// Example:
	// "^/old/[0-9]+/": "/new",
	// "^/api/.+?/(.*)": "/v2/$1",
	RegexRewrite map[*regexp.Regexp]string
	// Context key to store selected ProxyTarget into context.
	// Optional. Default value "target".
	ContextKey string
	// To customize the transport to remote.
	// Examples: If custom TLS certificates are required.
	// For WebSocket (raw) proxying only an *http.Transport's TLSClientConfig is used, to dial TLS.
	Transport http.RoundTripper
	// ModifyResponse defines function to modify response from ProxyTarget.
	ModifyResponse func(*http.Response) error
}
// ProxyTarget defines the upstream target.
type ProxyTarget struct {
	// Name identifies the target; the built-in balancers use it to reject duplicates in
	// AddTarget and to find the target in RemoveTarget. It is also included in proxy errors.
	Name string
	// URL is the upstream address requests are forwarded to.
	URL *url.URL
	// Meta is not read by the code in this file; presumably free-form user data — confirm.
	Meta map[string]any
}
// ProxyBalancer defines an interface to implement a load balancing technique.
type ProxyBalancer interface {
	// AddTarget adds target and reports whether it was added (built-in balancers return
	// false when a target with the same Name already exists).
	AddTarget(target *ProxyTarget) bool
	// RemoveTarget removes the target with the given name and reports whether one was found.
	RemoveTarget(targetName string) bool
	// Next selects the target for the current request. The Proxy middleware calls it once per
	// attempt (including retries) and passes a returned error to ProxyConfig.ErrorHandler.
	Next(c *echo.Context) (*ProxyTarget, error)
}
// commonBalancer holds the target list and the AddTarget/RemoveTarget logic shared by the
// built-in balancers. mutex guards targets and is also held by the embedding balancers'
// Next methods while they read their own selection state.
type commonBalancer struct {
	targets []*ProxyTarget
	mutex   sync.Mutex
}
// randomBalancer implements a random load balancing technique.
type randomBalancer struct {
	commonBalancer
	// random is only used while mutex is held; a *rand.Rand is not safe for concurrent use.
	random *rand.Rand
}
// roundRobinBalancer implements a round-robin load balancing technique.
type roundRobinBalancer struct {
	commonBalancer
	// tracking the index on `targets` slice for the next `*ProxyTarget` to be used
	i int
}
// DefaultProxyConfig is the default Proxy middleware config.
// It has no Balancer; callers must set one (see Proxy), otherwise ToMiddleware returns an error.
var DefaultProxyConfig = ProxyConfig{
	Skipper:    DefaultSkipper,
	ContextKey: "target",
}
// proxyRaw returns a handler that tunnels a raw (e.g. WebSocket) connection to target t.
//
// The client connection is hijacked, a connection to t.URL.Host is dialed (over TLS when
// config.Transport is an *http.Transport with TLSClientConfig set), the original request is
// written to it, and bytes are then copied in both directions until one side finishes.
// Failures are not returned; they are stored under the "_error" context key, which the Proxy
// middleware inspects after the handler returns.
func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler {
	var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
	if transport, ok := config.Transport.(*http.Transport); ok {
		if transport.TLSClientConfig != nil {
			d := tls.Dialer{
				Config: transport.TLSClientConfig,
			}
			dialFunc = d.DialContext
		}
	}
	if dialFunc == nil {
		var d net.Dialer
		dialFunc = d.DialContext
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		in, _, err := http.NewResponseController(w).Hijack()
		if err != nil {
			c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
			return
		}
		defer in.Close()

		out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host)
		if err != nil {
			c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
			return
		}
		defer out.Close()

		// Write header
		err = r.Write(out)
		if err != nil {
			c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL)))
			return
		}

		errCh := make(chan error, 2)
		cp := func(dst io.Writer, src io.Reader) {
			_, copyErr := io.Copy(dst, src)
			errCh <- copyErr
		}
		go cp(out, in)
		go cp(in, out)

		// Once either direction finishes (its source closed or I/O failed) the tunnel is over.
		// Close both connections so the other io.Copy, which may be blocked reading from a peer
		// that does not close on its own, returns. Without this the handler and both
		// connections could hang indefinitely.
		err1 := <-errCh
		_ = in.Close()
		_ = out.Close()
		// Wait for the second goroutine too so it never outlives the handler. Its error is
		// usually caused by the Close calls above, so net.ErrClosed is not reported.
		err2 := <-errCh

		if err1 != nil && err1 != io.EOF {
			c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err1, t.URL))
		} else if err2 != nil && err2 != io.EOF && !errors.Is(err2, net.ErrClosed) {
			c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err2, t.URL))
		}
	})
}
// NewRandomBalancer returns a random proxy balancer.
//
// targets is stored as given (not copied); AddTarget/RemoveTarget update the balancer's list later.
func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {
	// G404 (CWE-338): Use of weak random number generator (math/rand or math/rand/v2 instead of crypto/rand)
	// this random is used to select next target. I can not think of reason this must be cryptographically safe. If you can - please open PR.
	seed := int64(time.Now().Nanosecond())
	return &randomBalancer{
		commonBalancer: commonBalancer{targets: targets},
		random:         rand.New(rand.NewSource(seed)), // #nosec G404
	}
}

// NewRoundRobinBalancer returns a round-robin proxy balancer.
//
// targets is stored as given (not copied); selection starts at the first target.
func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer {
	return &roundRobinBalancer{
		commonBalancer: commonBalancer{targets: targets},
	}
}
// AddTarget appends an upstream target to the list and returns `true`.
//
// Target names are unique: if a target with the same name is already registered, nothing
// is added and `false` is returned.
func (b *commonBalancer) AddTarget(target *ProxyTarget) bool {
	b.mutex.Lock()
	defer b.mutex.Unlock()

	for i := range b.targets {
		if b.targets[i].Name == target.Name {
			return false // duplicate name
		}
	}
	b.targets = append(b.targets, target)
	return true
}
// RemoveTarget removes an upstream target from the list by name.
//
// Returns `true` on success, `false` if no target with the name is found.
//
// A new slice is built instead of shifting elements in place. The balancer may be holding the
// very slice that was passed to NewRandomBalancer/NewRoundRobinBalancer, and an in-place
// append(targets[:i], targets[i+1:]...) would silently rewrite the caller's slice.
func (b *commonBalancer) RemoveTarget(name string) bool {
	b.mutex.Lock()
	defer b.mutex.Unlock()
	for i, t := range b.targets {
		if t.Name != name {
			continue
		}
		remaining := make([]*ProxyTarget, 0, len(b.targets)-1)
		remaining = append(remaining, b.targets[:i]...)
		remaining = append(remaining, b.targets[i+1:]...)
		b.targets = remaining
		return true
	}
	return false
}
// Next returns an upstream target chosen uniformly at random.
//
// Note: `nil` is returned in case upstream target list is empty. With a single target
// the random generator is not consulted.
func (b *randomBalancer) Next(_ *echo.Context) (*ProxyTarget, error) {
	b.mutex.Lock()
	defer b.mutex.Unlock()

	switch n := len(b.targets); n {
	case 0:
		return nil, nil
	case 1:
		return b.targets[0], nil
	default:
		return b.targets[b.random.Intn(n)], nil
	}
}
// Next returns an upstream target using round-robin technique. In the case
// where a previously failed request is being retried, the round-robin
// balancer will attempt to use the next target relative to the original
// request. If the list of targets held by the balancer is modified while a
// failed request is being retried, it is possible that the balancer will
// return the original failed target.
//
// The index chosen for a request is remembered in the request context, which is how
// retries are recognised.
//
// Note: `nil` is returned in case upstream target list is empty.
func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
	b.mutex.Lock()
	defer b.mutex.Unlock()

	n := len(b.targets)
	switch n {
	case 0:
		return nil, nil
	case 1:
		return b.targets[0], nil
	}

	const lastIdxKey = "_round_robin_last_index"
	var idx int
	if prev := c.Get(lastIdxKey); prev != nil {
		// Retry: move past the target that just failed instead of reusing it.
		idx = prev.(int) + 1
		if idx >= n {
			idx = 0
		}
	} else {
		// First attempt for this request: take the balancer-wide cursor.
		if b.i >= n {
			b.i = 0
		}
		idx = b.i
		b.i++
	}
	c.Set(lastIdxKey, idx)
	return b.targets[idx], nil
}
// Proxy returns a Proxy middleware that uses DefaultProxyConfig with the given balancer.
//
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
// It panics if balancer is nil.
func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
	config := DefaultProxyConfig
	config.Balancer = balancer
	return ProxyWithConfig(config)
}

// ProxyWithConfig returns a Proxy middleware built from config, or panics if the
// configuration is invalid (see ProxyConfig.ToMiddleware).
//
// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration.
//
// Defaults: Skipper -> DefaultProxyConfig.Skipper, ContextKey -> DefaultProxyConfig.ContextKey,
// RetryFilter -> retry only *echo.HTTPError values with code 502, ErrorHandler -> return the
// error unchanged. Rewrite rules are compiled and merged with RegexRewrite into a new map, so the
// caller's RegexRewrite map is never modified. A nil Balancer is reported as an error.
func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Skipper == nil {
		config.Skipper = DefaultProxyConfig.Skipper
	}
	if config.ContextKey == "" {
		config.ContextKey = DefaultProxyConfig.ContextKey
	}
	if config.Balancer == nil {
		return nil, errors.New("echo proxy middleware requires balancer")
	}
	if config.RetryFilter == nil {
		config.RetryFilter = func(c *echo.Context, e error) bool {
			if httpErr, ok := e.(*echo.HTTPError); ok {
				return httpErr.Code == http.StatusBadGateway
			}
			return false
		}
	}
	if config.ErrorHandler == nil {
		config.ErrorHandler = func(c *echo.Context, err error) error {
			return err
		}
	}
	if config.Rewrite != nil {
		// Merge into a fresh map. config is a copy, but RegexRewrite still refers to the caller's
		// map. Writing into it would leak compiled Rewrite rules back to the caller, add duplicates
		// every time the same config is converted, and race with middleware already built from it.
		merged := make(map[*regexp.Regexp]string, len(config.RegexRewrite)+len(config.Rewrite))
		for k, v := range config.RegexRewrite {
			merged[k] = v
		}
		for k, v := range rewriteRulesRegex(config.Rewrite) {
			merged[k] = v
		}
		config.RegexRewrite = merged
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			req := c.Request()
			res := c.Response()
			if err := rewriteURL(config.RegexRewrite, req); err != nil {
				return config.ErrorHandler(c, err)
			}

			// Fix header
			// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
			// However, for backward compatibility, legacy behavior is preserved unless you configure Echo#IPExtractor.
			if req.Header.Get(echo.HeaderXRealIP) == "" || c.Echo().IPExtractor != nil {
				req.Header.Set(echo.HeaderXRealIP, c.RealIP())
			}
			if req.Header.Get(echo.HeaderXForwardedProto) == "" {
				req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
			}
			if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
				req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
			}

			retries := config.RetryCount
			for {
				tgt, err := config.Balancer.Next(c)
				if err != nil {
					return config.ErrorHandler(c, err)
				}
				if tgt == nil {
					// The built-in balancers return (nil, nil) when they have no targets. Proxying to a
					// nil target would dereference nil in proxyHTTP/proxyRaw, so fail with an error.
					return config.ErrorHandler(c, echo.NewHTTPError(http.StatusServiceUnavailable, "proxy balancer has no target to forward request to"))
				}
				c.Set(config.ContextKey, tgt)

				// If retrying a failed request, clear any previous errors from
				// context here so that balancers have the option to check for
				// errors that occurred using previous target
				if retries < config.RetryCount {
					c.Set("_error", nil)
				}

				// This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request
				// that Balancer may have replaced with c.SetRequest.
				req = c.Request()

				// Proxy
				switch {
				case c.IsWebSocket():
					proxyRaw(c, tgt, config).ServeHTTP(res, req)
				default: // even SSE requests
					proxyHTTP(c, tgt, config).ServeHTTP(res, req)
				}

				proxyErr, hasError := c.Get("_error").(error)
				if !hasError {
					return nil
				}
				// RetryFilter is only consulted while retries remain.
				if retries <= 0 || !config.RetryFilter(c, proxyErr) {
					return config.ErrorHandler(c, proxyErr)
				}
				retries--
			}
		}
	}, nil
}
// StatusCodeContextCanceled is a custom HTTP status code for situations
// where a client unexpectedly closed the connection to the server.
// There is no standard status code for "client closed connection", but
// various well-known HTTP clients and servers use 499 for it, so we use
// 499 too instead of a 5xx code, which would make this situation indistinguishable
// from real server errors. proxyHTTP reports canceled upstream requests with this code.
const StatusCodeContextCanceled = 499
// proxyHTTP returns an httputil.ReverseProxy that forwards the request to tgt using
// config.Transport and config.ModifyResponse.
//
// Proxy failures are not written to the response. Instead an *echo.HTTPError is stored under
// the "_error" context key for the Proxy middleware to retry or hand to ErrorHandler:
// StatusCodeContextCanceled when the client went away, http.StatusBadGateway otherwise.
func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler {
	rp := httputil.NewSingleHostReverseProxy(tgt.URL)
	rp.Transport = config.Transport
	rp.ModifyResponse = config.ModifyResponse
	rp.ErrorHandler = func(_ http.ResponseWriter, _ *http.Request, err error) {
		// If the client canceled the request (usually by closing the connection), report a
		// client error (4xx) instead of a server error (5xx) to correctly identify the situation.
		// The Go standard library (as of late 2020) wraps the exported, standard
		// context.Canceled error with an unexported garbage value, requiring a substring check, see
		// https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430
		// From Caddy https://github.com/caddyserver/caddy/blob/afa778ae05503f563af0d1015cdf7e5e78b1eeec/modules/caddyhttp/reverseproxy/reverseproxy.go#L1352
		if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "operation was canceled") {
			c.Set("_error", echo.NewHTTPError(StatusCodeContextCanceled, "client closed connection").Wrap(err))
			return
		}

		target := tgt.URL.String()
		if tgt.Name != "" {
			target = fmt.Sprintf("%s(%s)", tgt.Name, target)
		}
		c.Set("_error", echo.NewHTTPError(
			http.StatusBadGateway,
			"remote server unreachable, could not proxy request",
		).Wrap(fmt.Errorf("server: %s, err: %w", target, err)))
	}
	return rp
}
================================================
FILE: middleware/proxy_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"sync"
"testing"
"time"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
"golang.org/x/net/websocket"
)
// Assert expected with url.EscapedPath method to obtain the path.
func TestProxy(t *testing.T) {
// Setup
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "target 1")
}))
defer t1.Close()
url1, _ := url.Parse(t1.URL)
t2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "target 2")
}))
defer t2.Close()
url2, _ := url.Parse(t2.URL)
targets := []*ProxyTarget{
{
Name: "target 1",
URL: url1,
},
{
Name: "target 2",
URL: url2,
},
}
rb := NewRandomBalancer(nil)
// must add targets:
for _, target := range targets {
assert.True(t, rb.AddTarget(target))
}
// must ignore duplicates:
for _, target := range targets {
assert.False(t, rb.AddTarget(target))
}
// Random
e := echo.New()
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
body := rec.Body.String()
expected := map[string]bool{
"target 1": true,
"target 2": true,
}
assert.Condition(t, func() bool {
return expected[body]
})
for _, target := range targets {
assert.True(t, rb.RemoveTarget(target.Name))
}
assert.False(t, rb.RemoveTarget("unknown target"))
// Round-robin
rrb := NewRoundRobinBalancer(targets)
e = echo.New()
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
body = rec.Body.String()
assert.Equal(t, "target 1", body)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
body = rec.Body.String()
assert.Equal(t, "target 2", body)
// ModifyResponse
e = echo.New()
e.Use(ProxyWithConfig(ProxyConfig{
Balancer: rrb,
ModifyResponse: func(res *http.Response) error {
res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified")))
res.Header.Set("X-Modified", "1")
return nil
},
}))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, "modified", rec.Body.String())
assert.Equal(t, "1", rec.Header().Get("X-Modified"))
// ProxyTarget is set in context
contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) (err error) {
next(c)
assert.Contains(t, targets, c.Get("target"), "target is not set in context")
return nil
}
}
e = echo.New()
e.Use(contextObserver)
e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
}
// TestMustProxyWithConfig_emptyBalancerPanics checks that ProxyWithConfig panics with the
// configuration error from ToMiddleware when no Balancer is configured.
func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
	assert.PanicsWithError(t, "echo proxy middleware requires balancer", func() {
		ProxyWithConfig(ProxyConfig{Balancer: nil})
	})
}
// TestProxyRealIPHeader checks how the proxy sets X-Real-IP. An incoming header is kept
// unless an IPExtractor is configured, and a missing header is filled from the real IP.
func TestProxyRealIPHeader(t *testing.T) {
	// Setup
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
	defer upstream.Close()
	upstreamURL, _ := url.Parse(upstream.URL)
	balancer := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: upstreamURL}})
	e := echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()

	remoteAddrIP, _, _ := net.SplitHostPort(req.RemoteAddr)
	const (
		realIPHeaderIP  = "203.0.113.1"
		extractedRealIP = "203.0.113.10"
	)
	extractor := func(*http.Request) string {
		return extractedRealIP
	}

	tests := []struct {
		hasRealIPheader bool
		hasIPExtractor  bool
		expectedXRealIP string
	}{
		{hasRealIPheader: false, hasIPExtractor: false, expectedXRealIP: remoteAddrIP},
		{hasRealIPheader: false, hasIPExtractor: true, expectedXRealIP: extractedRealIP},
		{hasRealIPheader: true, hasIPExtractor: false, expectedXRealIP: realIPHeaderIP},
		{hasRealIPheader: true, hasIPExtractor: true, expectedXRealIP: extractedRealIP},
	}
	for _, tt := range tests {
		req.Header.Del(echo.HeaderXRealIP)
		if tt.hasRealIPheader {
			req.Header.Set(echo.HeaderXRealIP, realIPHeaderIP)
		}
		e.IPExtractor = nil
		if tt.hasIPExtractor {
			e.IPExtractor = extractor
		}
		e.ServeHTTP(rec, req)
		assert.Equal(t, tt.expectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor)
	}
}
// TestProxyRewrite checks that ProxyConfig.Rewrite rules are applied to the path sent
// upstream, that already-escaped sequences are not decoded or double-escaped, and that
// query strings are forwarded unchanged.
func TestProxyRewrite(t *testing.T) {
	var testCases = []struct {
		whenPath         string
		expectProxiedURI string
		expectStatus     int
	}{
		{
			whenPath:         "/api/users",
			expectProxiedURI: "/users",
			expectStatus:     http.StatusOK,
		},
		{
			whenPath:         "/js/main.js",
			expectProxiedURI: "/public/javascripts/main.js",
			expectStatus:     http.StatusOK,
		},
		{
			whenPath:         "/old",
			expectProxiedURI: "/new",
			expectStatus:     http.StatusOK,
		},
		{
			whenPath:         "/users/jack/orders/1",
			expectProxiedURI: "/user/jack/order/1",
			expectStatus:     http.StatusOK,
		},
		{
			whenPath:         "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
			expectProxiedURI: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
			expectStatus:     http.StatusOK,
		},
		{ // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped when proxying request
			whenPath:         "/api/new users",
			expectProxiedURI: "/new%20users",
			expectStatus:     http.StatusOK,
		},
		{ // query params should be proxied and not be modified
			whenPath:         "/api/users?limit=10",
			expectProxiedURI: "/users?limit=10",
			expectStatus:     http.StatusOK,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.whenPath, func(t *testing.T) {
			receivedRequestURI := make(chan string, 1)
			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
				// we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
				// if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
				receivedRequestURI <- r.RequestURI
			}))
			defer upstream.Close()
			serverURL, _ := url.Parse(upstream.URL)
			rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: serverURL}})
			// Rewrite
			e := echo.New()
			e.Use(ProxyWithConfig(ProxyConfig{
				Balancer: rrb,
				Rewrite: map[string]string{
					"/old":              "/new",
					"/api/*":            "/$1",
					"/js/*":             "/public/javascripts/$1",
					"/users/*/orders/*": "/user/$1/order/$2",
				},
			}))
			targetURL, _ := serverURL.Parse(tc.whenPath)
			req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)
			assert.Equal(t, tc.expectStatus, rec.Code)
			// Never block forever: if the proxy failed before reaching the upstream the
			// channel stays empty and the test must fail instead of hanging.
			select {
			case actualRequestURI := <-receivedRequestURI:
				assert.Equal(t, tc.expectProxiedURI, actualRequestURI)
			case <-time.After(5 * time.Second):
				t.Fatal("upstream did not receive the proxied request")
			}
		})
	}
}
// TestProxyRewriteRegex checks that both wildcard Rewrite rules and RegexRewrite rules
// (with capture-group substitution) change the URI that reaches the upstream, and that
// unmatched paths pass through untouched.
func TestProxyRewriteRegex(t *testing.T) {
	// Setup
	receivedRequestURI := make(chan string, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
		// we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
		// if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
		receivedRequestURI <- r.RequestURI
	}))
	defer upstream.Close()
	upstreamURL, _ := url.Parse(upstream.URL)
	rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: upstreamURL}})
	// Rewrite
	e := echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{
		Balancer: rrb,
		Rewrite: map[string]string{
			"^/a/*":     "/v1/$1",
			"^/b/*/c/*": "/v2/$2/$1",
			"^/c/*/*":   "/v3/$2",
		},
		RegexRewrite: map[*regexp.Regexp]string{
			regexp.MustCompile("^/x/.+?/(.*)"):   "/v4/$1",
			regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
		},
	}))
	testCases := []struct {
		requestPath string
		statusCode  int
		expectPath  string
	}{
		{"/unmatched", http.StatusOK, "/unmatched"},
		{"/a/test", http.StatusOK, "/v1/test"},
		{"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"},
		{"/c/ignore/test", http.StatusOK, "/v3/test"},
		{"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"},
		{"/x/ignore/test", http.StatusOK, "/v4/test"},
		{"/y/foo/bar", http.StatusOK, "/v5/bar/foo"},
		// NB: fragment is not added by golang httputil.NewSingleHostReverseProxy implementation
		// $2 = `bar?q=1#frag`, $1 = `foo`. replaced uri = `/v5/bar?q=1#frag/foo` but httputil.NewSingleHostReverseProxy does not send `#frag/foo` (currently)
		{"/y/foo/bar?q=1#frag", http.StatusOK, "/v5/bar?q=1"},
	}
	for _, tc := range testCases {
		t.Run(tc.requestPath, func(t *testing.T) {
			targetURL, _ := url.Parse(tc.requestPath)
			req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)
			assert.Equal(t, tc.statusCode, rec.Code)
			// Never block forever: if the proxy failed before reaching the upstream the
			// channel stays empty and the test must fail instead of hanging.
			select {
			case actualRequestURI := <-receivedRequestURI:
				assert.Equal(t, tc.expectPath, actualRequestURI)
			case <-time.After(5 * time.Second):
				t.Fatal("upstream did not receive the proxied request")
			}
		})
	}
}
// TestProxyError checks that the random balancer accepts each target once and rejects
// duplicates, and that proxying to unreachable targets yields 502 Bad Gateway without
// touching the request path.
func TestProxyError(t *testing.T) {
	// Both targets point at local ports where nothing is expected to listen.
	unreachable1, _ := url.Parse("http://127.0.0.1:27121")
	unreachable2, _ := url.Parse("http://127.0.0.1:27122")
	targets := []*ProxyTarget{
		{Name: "target 1", URL: unreachable1},
		{Name: "target 2", URL: unreachable2},
	}

	rb := NewRandomBalancer(nil)
	for _, tgt := range targets {
		assert.True(t, rb.AddTarget(tgt), "first AddTarget must succeed")
	}
	for _, tgt := range targets {
		assert.False(t, rb.AddTarget(tgt), "duplicate AddTarget must be rejected")
	}

	e := echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.URL.Path = "/api/users"
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)

	assert.Equal(t, "/api/users", req.URL.Path)
	assert.Equal(t, http.StatusBadGateway, rec.Code)
}
func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
var timeoutStop sync.WaitGroup
timeoutStop.Add(1)
HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
timeoutStop.Wait() // wait until we have canceled the request
w.WriteHeader(http.StatusOK)
}))
defer HTTPTarget.Close()
targetURL, _ := url.Parse(HTTPTarget.URL)
target := &ProxyTarget{
Name: "target",
URL: targetURL,
}
rb := NewRandomBalancer(nil)
assert.True(t, rb.AddTarget(target))
e := echo.New()
e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx, cancel := context.WithCancel(req.Context())
req = req.WithContext(ctx)
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()
e.ServeHTTP(rec, req)
timeoutStop.Done()
assert.Equal(t, 499, rec.Code)
}
// testProvider is a balancer stub for Proxy whose Next always yields the preconfigured
// target and err. commonBalancer is embedded, presumably to supply the remaining
// balancer methods (AddTarget/RemoveTarget) — confirm against proxy.go.
type testProvider struct {
	commonBalancer
	target *ProxyTarget
	err    error
}

// Next ignores the request context and returns the canned target/error pair.
func (tp *testProvider) Next(_ *echo.Context) (*ProxyTarget, error) {
	target, err := tp.target, tp.err
	return target, err
}
func TestTargetProvider(t *testing.T) {
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "target 1")
}))
defer t1.Close()
url1, _ := url.Parse(t1.URL)
e := echo.New()
tp := &testProvider{}
tp.target = &ProxyTarget{Name: "target 1", URL: url1}
e.Use(Proxy(tp))
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
e.ServeHTTP(rec, req)
body := rec.Body.String()
assert.Equal(t, "target 1", body)
}
// TestFailNextTarget checks that an error returned by the balancer's Next is rendered
// as the response instead of proxying to the returned target.
func TestFailNextTarget(t *testing.T) {
	dummyURL, err := url.Parse("http://dummy:8080")
	assert.NoError(t, err)

	provider := &testProvider{
		target: &ProxyTarget{Name: "target 1", URL: dummyURL},
		err:    echo.NewHTTPError(http.StatusInternalServerError, "method could not select target"),
	}
	e := echo.New()
	e.Use(Proxy(provider))

	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	assert.Equal(t, "{\"message\":\"method could not select target\"}\n", rec.Body.String())
}
func TestRandomBalancerWithNoTargets(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Assert balancer with empty targets does return `nil` on `Next()`
rb := NewRandomBalancer(nil)
target, err := rb.Next(c)
assert.Nil(t, target)
assert.NoError(t, err)
}
func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
// Assert balancer with empty targets does return `nil` on `Next()`
rrb := NewRoundRobinBalancer([]*ProxyTarget{})
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
target, err := rrb.Next(c)
assert.Nil(t, target)
assert.NoError(t, err)
}
// TestProxyRetries exercises ProxyConfig.RetryCount together with RetryFilter.
// Each case lists, in order, the RetryFilter results it expects to be asked for;
// the filter wrapper fails the test on any extra call, and the case fails if some
// listed results were never consumed.
func TestProxyRetries(t *testing.T) {
	newServer := func(res int) (*url.URL, *httptest.Server) {
		server := httptest.NewServer(
			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(res)
			}),
		)
		targetURL, _ := url.Parse(server.URL)
		return targetURL, server
	}
	targetURL, server := newServer(http.StatusOK)
	defer server.Close()
	goodTarget := &ProxyTarget{
		Name: "Good",
		URL:  targetURL,
	}
	targetURL, server = newServer(http.StatusBadRequest)
	defer server.Close()
	goodTargetWith40X := &ProxyTarget{
		Name: "Good with 40X",
		URL:  targetURL,
	}
	targetURL, _ = url.Parse("http://127.0.0.1:27121")
	badTarget := &ProxyTarget{
		Name: "Bad",
		URL:  targetURL,
	}
	alwaysRetryFilter := func(c *echo.Context, e error) bool { return true }
	neverRetryFilter := func(c *echo.Context, e error) bool { return false }
	testCases := []struct {
		name             string
		retryCount       int
		retryFilters     []func(c *echo.Context, e error) bool
		targets          []*ProxyTarget
		expectedResponse int
	}{
		{
			name: "retry count 0 does not attempt retry on fail",
			targets: []*ProxyTarget{
				badTarget,
				goodTarget,
			},
			expectedResponse: http.StatusBadGateway,
		},
		{
			name:       "retry count 1 does not attempt retry on success",
			retryCount: 1,
			targets: []*ProxyTarget{
				goodTarget,
			},
			expectedResponse: http.StatusOK,
		},
		{
			name:       "retry count 1 does retry on handler return true",
			retryCount: 1,
			retryFilters: []func(c *echo.Context, e error) bool{
				alwaysRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				goodTarget,
			},
			expectedResponse: http.StatusOK,
		},
		{
			name:       "retry count 1 does not retry on handler return false",
			retryCount: 1,
			retryFilters: []func(c *echo.Context, e error) bool{
				neverRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				goodTarget,
			},
			expectedResponse: http.StatusBadGateway,
		},
		{
			name:       "retry count 2 returns error when no more retries left",
			retryCount: 2,
			retryFilters: []func(c *echo.Context, e error) bool{
				alwaysRetryFilter,
				alwaysRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				badTarget,
				badTarget,
				goodTarget, //Should never be reached as only 2 retries
			},
			expectedResponse: http.StatusBadGateway,
		},
		{
			name:       "retry count 3 returns error when retries left but handler returns false",
			retryCount: 3,
			retryFilters: []func(c *echo.Context, e error) bool{
				alwaysRetryFilter,
				alwaysRetryFilter,
				neverRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				badTarget,
				badTarget,
				goodTarget, //Should never be reached as retry handler returns false on 3rd check
			},
			expectedResponse: http.StatusBadGateway,
		},
		{
			name:       "retry count 3 succeeds",
			retryCount: 3,
			retryFilters: []func(c *echo.Context, e error) bool{
				alwaysRetryFilter,
				alwaysRetryFilter,
				alwaysRetryFilter,
			},
			targets: []*ProxyTarget{
				badTarget,
				badTarget,
				badTarget,
				goodTarget,
			},
			expectedResponse: http.StatusOK,
		},
		{
			name:       "40x responses are not retried",
			retryCount: 1,
			targets: []*ProxyTarget{
				goodTargetWith40X,
				goodTarget,
			},
			expectedResponse: http.StatusBadRequest,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			retryFilterCall := 0
			// retryFilter pops the next expected result from tc.retryFilters.
			retryFilter := func(c *echo.Context, e error) bool {
				if len(tc.retryFilters) == 0 {
					t.Fatalf("unexpected call #%d to retry handler", retryFilterCall+1)
				}
				retryFilterCall++
				nextRetryFilter := tc.retryFilters[0]
				tc.retryFilters = tc.retryFilters[1:]
				return nextRetryFilter(c, e)
			}
			e := echo.New()
			e.Use(ProxyWithConfig(
				ProxyConfig{
					Balancer:    NewRoundRobinBalancer(tc.targets),
					RetryCount:  tc.retryCount,
					RetryFilter: retryFilter,
				},
			))
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)
			assert.Equal(t, tc.expectedResponse, rec.Code)
			if len(tc.retryFilters) > 0 {
				t.Fatalf("expected %d more retry handler calls", len(tc.retryFilters))
			}
		})
	}
}
// TestProxyRetryWithBackendTimeout sends concurrent requests through a round-robin
// balancer whose first target responds too slowly for the transport's
// ResponseHeaderTimeout. With RetryCount 1 every request is expected to end with 200.
func TestProxyRetryWithBackendTimeout(t *testing.T) {
	// Headers must arrive within 500ms; the slow backend needs a full second.
	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.ResponseHeaderTimeout = 500 * time.Millisecond

	slowBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(time.Second)
		w.WriteHeader(http.StatusNotFound)
	}))
	defer slowBackend.Close()
	slowURL, _ := url.Parse(slowBackend.URL)

	okBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer okBackend.Close()
	okURL, _ := url.Parse(okBackend.URL)

	e := echo.New()
	e.Use(ProxyWithConfig(ProxyConfig{
		Transport: transport,
		Balancer: NewRoundRobinBalancer([]*ProxyTarget{
			{Name: "Timeout", URL: slowURL},
			{Name: "Good", URL: okURL},
		}),
		RetryCount: 1,
	}))

	const concurrentRequests = 20
	var wg sync.WaitGroup
	wg.Add(concurrentRequests)
	for i := 0; i < concurrentRequests; i++ {
		go func() {
			defer wg.Done()
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
			assert.Equal(t, http.StatusOK, rec.Code)
		}()
	}
	wg.Wait()
}
// TestProxyErrorHandler checks that ProxyConfig.ErrorHandler is only invoked when the
// upstream request fails, receives a 502 *echo.HTTPError, and that the error it returns
// is the one that reaches echo's HTTPErrorHandler.
func TestProxyErrorHandler(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	goodURL, _ := url.Parse(server.URL)
	defer server.Close()
	goodTarget := &ProxyTarget{
		Name: "Good",
		URL:  goodURL,
	}
	badURL, _ := url.Parse("http://127.0.0.1:27121")
	badTarget := &ProxyTarget{
		Name: "Bad",
		URL:  badURL,
	}
	transformedError := errors.New("a new error")
	testCases := []struct {
		name   string
		target *ProxyTarget
		// errorHandler receives the subtest's *testing.T so failures are reported on
		// (and never stop) the parent test from inside a subtest.
		errorHandler func(t *testing.T, c *echo.Context, e error) error
		// expectFinalError is nil when no error may reach echo's HTTPErrorHandler.
		expectFinalError func(t *testing.T, err error)
	}{
		{
			name:   "Error handler not invoked when request success",
			target: goodTarget,
			errorHandler: func(t *testing.T, c *echo.Context, e error) error {
				assert.Fail(t, "error handler should not be invoked")
				return e
			},
		},
		{
			name:   "Error handler invoked when request fails",
			target: badTarget,
			errorHandler: func(t *testing.T, c *echo.Context, e error) error {
				httpErr, ok := e.(*echo.HTTPError)
				// Only inspect Code when the assertion succeeded; httpErr is nil otherwise.
				if assert.True(t, ok, "expected http error to be passed to handler") {
					assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
				}
				return transformedError
			},
			expectFinalError: func(t *testing.T, err error) {
				assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			e.Use(ProxyWithConfig(
				ProxyConfig{
					Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
					ErrorHandler: func(c *echo.Context, err error) error {
						return tc.errorHandler(t, c, err)
					},
				},
			))
			errorHandlerCalled := false
			dheh := echo.DefaultHTTPErrorHandler(false)
			e.HTTPErrorHandler = func(c *echo.Context, err error) {
				errorHandlerCalled = true
				if tc.expectFinalError == nil {
					// Previously this dereferenced a nil func and panicked.
					t.Errorf("unexpected error reached HTTP error handler: %v", err)
				} else {
					tc.expectFinalError(t, err)
				}
				dheh(c, err)
			}
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)
			if !errorHandlerCalled && tc.expectFinalError != nil {
				t.Fatalf("error handler was not called")
			}
		})
	}
}
// testContextKey is a dedicated key type for values these tests store with
// context.WithValue, following the convention of not using built-in types as keys.
type testContextKey string

// customBalancer always proxies to its single fixed target. Next additionally stores
// "CUSTOM_BALANCER" under testContextKey("FROM_BALANCER") in the request context, so
// tests can check which request context ModifyResponse observes.
type customBalancer struct {
	target *ProxyTarget
}

// AddTarget never adds anything and always reports false.
func (b *customBalancer) AddTarget(_ *ProxyTarget) bool { return false }

// RemoveTarget never removes anything and always reports false.
func (b *customBalancer) RemoveTarget(_ string) bool { return false }

// Next replaces the request with a copy carrying the marker value and returns the fixed target.
func (b *customBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
	req := c.Request()
	marked := context.WithValue(req.Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
	c.SetRequest(req.WithContext(marked))
	return b.target, nil
}
func TestModifyResponseUseContext(t *testing.T) {
server := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}),
)
defer server.Close()
targetURL, _ := url.Parse(server.URL)
e := echo.New()
e.Use(ProxyWithConfig(
ProxyConfig{
Balancer: &customBalancer{
target: &ProxyTarget{
Name: "tst",
URL: targetURL,
},
},
RetryCount: 1,
ModifyResponse: func(res *http.Response) error {
val := res.Request.Context().Value(testContextKey("FROM_BALANCER"))
if valStr, ok := val.(string); ok {
res.Header.Set("FROM_BALANCER", valStr)
}
return nil
},
},
))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "OK", rec.Body.String())
assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
}
// createSimpleWebSocketServer starts a WebSocket server that answers every received
// message with the same message until the peer disconnects. It serves TLS when
// serveTLS is true. The caller must close the returned server.
func createSimpleWebSocketServer(serveTLS bool) *httptest.Server {
	echoMessages := func(conn *websocket.Conn) {
		defer conn.Close()
		for {
			var msg string
			if err := websocket.Message.Receive(conn, &msg); err != nil {
				return
			}
			websocket.Message.Send(conn, msg) // reply with the same payload
		}
	}
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		websocket.Server{Handler: echoMessages}.ServeHTTP(w, r)
	})
	if !serveTLS {
		return httptest.NewServer(handler)
	}
	return httptest.NewTLSServer(handler)
}
// createSimpleProxyServer starts an Echo server (TLS when serveTLS is true) whose only
// middleware proxies every request to srv. When toTLS is true the target scheme is set
// to "wss" and a cloned default transport with certificate verification disabled is
// used, so httptest's self-signed TLS certificate is accepted.
// The caller must close the returned server.
func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server {
	t.Helper()
	tgtURL, err := url.Parse(srv.URL)
	if err != nil {
		t.Fatalf("parse target URL %q: %v", srv.URL, err)
	}
	e := echo.New()
	if toTLS {
		// proxy to TLS target
		tgtURL.Scheme = "wss"
		balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
		defaultTransport, ok := http.DefaultTransport.(*http.Transport)
		if !ok {
			t.Fatal("Default transport is not of type *http.Transport")
		}
		transport := defaultTransport.Clone()
		transport.TLSClientConfig = &tls.Config{
			InsecureSkipVerify: true, // test-only: the target uses a self-signed certificate
		}
		e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
	} else {
		// proxy to non-TLS target
		balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
		e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
	}
	if serveTLS {
		// serve proxy server with TLS
		return httptest.NewTLSServer(e)
	}
	// serve proxy server without TLS
	return httptest.NewServer(e)
}
// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) {
	/*
		Arrange
	*/
	// Create a WebSocket test server (non-TLS)
	srv := createSimpleWebSocketServer(false)
	defer srv.Close()
	// create proxy server (non-TLS to non-TLS)
	ts := createSimpleProxyServer(t, srv, false, false)
	defer ts.Close()
	tsURL, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("parse proxy URL %q: %v", ts.URL, err)
	}
	tsURL.Scheme = "ws"
	tsURL.Path = "/"
	/*
		Act
	*/
	// Connect to the proxy WebSocket. Stop on failure: a nil connection would panic
	// in Send and in the deferred Close.
	wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
	if err != nil {
		t.Fatalf("dial proxy websocket: %v", err)
	}
	defer wsConn.Close()
	// Send message
	sendMsg := "Hello, Non TLS WebSocket!"
	if err := websocket.Message.Send(wsConn, sendMsg); err != nil {
		t.Fatalf("send message: %v", err)
	}
	/*
		Assert
	*/
	// Read response
	var recvMsg string
	err = websocket.Message.Receive(wsConn, &recvMsg)
	assert.NoError(t, err)
	assert.Equal(t, sendMsg, recvMsg)
}
// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) {
	/*
		Arrange
	*/
	// Create a WebSocket test server (TLS)
	srv := createSimpleWebSocketServer(true)
	defer srv.Close()
	// create proxy server (TLS to TLS)
	ts := createSimpleProxyServer(t, srv, true, true)
	defer ts.Close()
	tsURL, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("parse proxy URL %q: %v", ts.URL, err)
	}
	tsURL.Scheme = "wss"
	tsURL.Path = "/"
	/*
		Act
	*/
	origin, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("parse origin %q: %v", ts.URL, err)
	}
	config := &websocket.Config{
		Location:  tsURL,
		Origin:    origin,
		TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
		Version:   websocket.ProtocolVersionHybi13,
	}
	// Stop on failure: a nil connection would panic in Send and in the deferred Close.
	wsConn, err := websocket.DialConfig(config)
	if err != nil {
		t.Fatalf("dial proxy websocket: %v", err)
	}
	defer wsConn.Close()
	// Send message
	sendMsg := "Hello, TLS to TLS WebSocket!"
	if err := websocket.Message.Send(wsConn, sendMsg); err != nil {
		t.Fatalf("send message: %v", err)
	}
	// Read response
	var recvMsg string
	err = websocket.Message.Receive(wsConn, &recvMsg)
	assert.NoError(t, err)
	assert.Equal(t, sendMsg, recvMsg)
}
// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) {
	/*
		Arrange
	*/
	// Create a WebSocket test server (TLS)
	srv := createSimpleWebSocketServer(true)
	defer srv.Close()
	// create proxy server (Non-TLS to TLS)
	ts := createSimpleProxyServer(t, srv, false, true)
	defer ts.Close()
	tsURL, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("parse proxy URL %q: %v", ts.URL, err)
	}
	tsURL.Scheme = "ws"
	tsURL.Path = "/"
	/*
		Act
	*/
	// Connect to the proxy WebSocket. Stop on failure: a nil connection would panic
	// in Send and in the deferred Close.
	wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
	if err != nil {
		t.Fatalf("dial proxy websocket: %v", err)
	}
	defer wsConn.Close()
	// Send message
	sendMsg := "Hello, Non TLS to TLS WebSocket!"
	if err := websocket.Message.Send(wsConn, sendMsg); err != nil {
		t.Fatalf("send message: %v", err)
	}
	/*
		Assert
	*/
	// Read response
	var recvMsg string
	err = websocket.Message.Receive(wsConn, &recvMsg)
	assert.NoError(t, err)
	assert.Equal(t, sendMsg, recvMsg)
}
// TestProxyWithConfigWebSocketTLS2NonTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
	/*
		Arrange
	*/
	// Create a WebSocket test server (non-TLS)
	srv := createSimpleWebSocketServer(false)
	defer srv.Close()
	// create proxy server (TLS to non-TLS)
	ts := createSimpleProxyServer(t, srv, true, false)
	defer ts.Close()
	tsURL, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("parse proxy URL %q: %v", ts.URL, err)
	}
	tsURL.Scheme = "wss"
	tsURL.Path = "/"
	/*
		Act
	*/
	origin, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("parse origin %q: %v", ts.URL, err)
	}
	config := &websocket.Config{
		Location:  tsURL,
		Origin:    origin,
		TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
		Version:   websocket.ProtocolVersionHybi13,
	}
	// Stop on failure: a nil connection would panic in Send and in the deferred Close.
	wsConn, err := websocket.DialConfig(config)
	if err != nil {
		t.Fatalf("dial proxy websocket: %v", err)
	}
	defer wsConn.Close()
	// Send message
	sendMsg := "Hello, TLS to NoneTLS WebSocket!"
	if err := websocket.Message.Send(wsConn, sendMsg); err != nil {
		t.Fatalf("send message: %v", err)
	}
	// Read response
	var recvMsg string
	err = websocket.Message.Receive(wsConn, &recvMsg)
	assert.NoError(t, err)
	assert.Equal(t, sendMsg, recvMsg)
}
================================================
FILE: middleware/rate_limiter.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"errors"
"math"
"net/http"
"sync"
"time"
"github.com/labstack/echo/v5"
"golang.org/x/time/rate"
)
// RateLimiterStore is the interface to be implemented by custom stores.
//
// Allow reports whether a request from identifier may proceed. When it returns false,
// the rate limiter middleware calls RateLimiterConfig.DenyHandler with the returned
// error (which may be nil). An error returned together with true is ignored by the
// middleware.
type RateLimiterStore interface {
	Allow(identifier string) (bool, error)
}
// RateLimiterConfig defines the configuration for the rate limiter.
// Store is required. ToMiddleware fills a nil Skipper, IdentifierExtractor,
// ErrorHandler or DenyHandler from DefaultRateLimiterConfig; BeforeFunc has no default.
type RateLimiterConfig struct {
	// Skipper defines a function to skip the middleware. When it returns true the
	// request goes straight to the next handler and the Store is not consulted.
	Skipper Skipper
	// BeforeFunc, if non-nil, is called for every request that is not skipped,
	// before IdentifierExtractor runs.
	BeforeFunc BeforeFunc
	// IdentifierExtractor uses *echo.Context to extract the identifier for a visitor.
	// The identifier is the key passed to Store.Allow. The default uses c.RealIP().
	IdentifierExtractor Extractor
	// Store defines a store for the rate limiter. Required: ToMiddleware returns an
	// error when it is nil.
	Store RateLimiterStore
	// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error.
	// Its return value is returned by the middleware. The default wraps err in
	// ErrExtractorError (403 Forbidden).
	ErrorHandler func(c *echo.Context, err error) error
	// DenyHandler provides a handler to be called when RateLimiter denies access, i.e.
	// when Store.Allow returns false. err is the error returned by Allow and may be nil.
	// The default wraps err in ErrRateLimitExceeded (429 Too Many Requests).
	DenyHandler func(c *echo.Context, identifier string, err error) error
}
// Extractor is used to extract data from *echo.Context. The rate limiter uses it to
// obtain the identifier that is passed to RateLimiterStore.Allow.
type Extractor func(c *echo.Context) (string, error)

// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded.
// It carries HTTP status 429 (Too Many Requests) and is what the default DenyHandler wraps.
var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")

// ErrExtractorError denotes an error raised when extractor function is unsuccessful.
// It carries HTTP status 403 (Forbidden) and is what the default ErrorHandler wraps.
var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
// DefaultRateLimiterConfig holds the values ToMiddleware falls back to for unset fields:
// DefaultSkipper as the skipper, the client's RealIP() as the identifier, extractor
// failures wrapped in ErrExtractorError (403), and denied requests wrapped in
// ErrRateLimitExceeded (429). It contains no Store; one must always be supplied.
var DefaultRateLimiterConfig = RateLimiterConfig{
	Skipper: DefaultSkipper,
	IdentifierExtractor: func(ctx *echo.Context) (string, error) {
		return ctx.RealIP(), nil
	},
	DenyHandler: func(_ *echo.Context, _ string, err error) error {
		return ErrRateLimitExceeded.Wrap(err)
	},
	ErrorHandler: func(_ *echo.Context, err error) error {
		return ErrExtractorError.Wrap(err)
	},
}
// RateLimiter returns a rate limiting middleware that uses store and takes every other
// setting from DefaultRateLimiterConfig. It panics if store is nil.
//
// Example (20 requests per second per client IP):
//
//	e := echo.New()
//	limiterStore := middleware.NewRateLimiterMemoryStore(20)
//	e.GET("/rate-limited", func(c *echo.Context) error {
//		return c.String(http.StatusOK, "test")
//	}, middleware.RateLimiter(limiterStore))
func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc {
	cfg := DefaultRateLimiterConfig
	cfg.Store = store
	return RateLimiterWithConfig(cfg)
}
/*
RateLimiterWithConfig returns a rate limiting middleware built from config.
It panics when config is invalid (for example when config.Store is nil); use
RateLimiterConfig.ToMiddleware to receive the error instead.

Example:

	e := echo.New()
	config := middleware.RateLimiterConfig{
		Skipper: middleware.DefaultSkipper,
		Store: middleware.NewRateLimiterMemoryStoreWithConfig(
			middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute},
		),
		IdentifierExtractor: func(c *echo.Context) (string, error) {
			id := c.RealIP()
			return id, nil
		},
		// called when IdentifierExtractor returns an error
		ErrorHandler: func(c *echo.Context, err error) error {
			return c.JSON(http.StatusForbidden, nil)
		},
		// called when the store denies the request
		DenyHandler: func(c *echo.Context, identifier string, err error) error {
			return c.JSON(http.StatusTooManyRequests, nil)
		},
	}

	e.GET("/rate-limited", func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}, middleware.RateLimiterWithConfig(config))
*/
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
	return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid
// configuration. Store is mandatory; nil Skipper, IdentifierExtractor, ErrorHandler and
// DenyHandler are taken from DefaultRateLimiterConfig.
//
// For each request that is not skipped, the middleware runs BeforeFunc (if set),
// extracts the identifier (delegating failures to ErrorHandler), asks the Store whether
// the identifier may proceed (delegating denials to DenyHandler) and otherwise calls next.
func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Store == nil {
		return nil, errors.New("echo rate limiter store configuration must be provided")
	}

	defaults := DefaultRateLimiterConfig
	if config.Skipper == nil {
		config.Skipper = defaults.Skipper
	}
	if config.IdentifierExtractor == nil {
		config.IdentifierExtractor = defaults.IdentifierExtractor
	}
	if config.ErrorHandler == nil {
		config.ErrorHandler = defaults.ErrorHandler
	}
	if config.DenyHandler == nil {
		config.DenyHandler = defaults.DenyHandler
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}
			if before := config.BeforeFunc; before != nil {
				before(c)
			}

			identifier, err := config.IdentifierExtractor(c)
			if err != nil {
				return config.ErrorHandler(c, err)
			}

			// An error returned alongside allowed == true is deliberately ignored.
			allowed, allowErr := config.Store.Allow(identifier)
			if !allowed {
				return config.DenyHandler(c, identifier, allowErr)
			}
			return next(c)
		}
	}, nil
}
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter.
// It keeps one Visitor (a *rate.Limiter plus a last-seen time) per identifier in a map
// guarded by mutex. Construct it with NewRateLimiterMemoryStore or
// NewRateLimiterMemoryStoreWithConfig: the zero value has a nil timeNow (and a nil
// visitors map), so calling Allow on it panics.
type RateLimiterMemoryStore struct {
	visitors    map[string]*Visitor // limiter state keyed by the identifier passed to Allow
	mutex       sync.Mutex          // guards visitors and lastCleanup
	rate        float64             // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit
	burst       int                 // burst size given to every newly created rate.Limiter
	expiresIn   time.Duration       // idle time after which a visitor is removed; also the minimum interval between cleanups
	lastCleanup time.Time           // when cleanupStaleVisitors last ran (initially the creation time)
	timeNow     func() time.Time    // clock used by Allow; time.Now unless replaced (presumably by tests)
}
// Visitor signifies a unique user's limiter details: the embedded *rate.Limiter that
// decides whether a request is allowed, and the time the identifier was last seen,
// which cleanupStaleVisitors compares against the store's expiresIn.
type Visitor struct {
	*rate.Limiter
	lastSeen time.Time // updated on every RateLimiterMemoryStore.Allow call for this identifier
}
/*
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with
the provided rate (as req/s).
For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.

Burst and ExpiresIn are set to default values: ExpiresIn becomes
DefaultRateLimiterMemoryStoreConfig.ExpiresIn, and Burst becomes the rate rounded up
to the nearest integer with a minimum of 1 (e.g. a rate of 0.5 or 1 gives a burst of 1,
a rate of 2.3 gives a burst of 3).

Example (with 20 requests/sec):

	limiterStore := middleware.NewRateLimiterMemoryStore(20)
*/
func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) {
	return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate: rateLimit,
	})
}
/*
NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore
with the provided configuration. Rate must be provided. If Burst is not provided
(or is 0) it is set to Rate rounded up to the nearest integer, with a minimum of 1.
If ExpiresIn is 0, DefaultRateLimiterMemoryStoreConfig.ExpiresIn is used.

The built-in memory store is usually capable for modest loads. For higher loads other
store implementations should be considered.

Characteristics:
  - Concurrency above 100 parallel requests may cause measurable lock contention
  - A high number of different IP addresses (above 16000) may be impacted by the internally used Go map
  - A high number of requests from a single IP address may cause lock contention

Example:

	limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig(
		middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute},
	)
*/
func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) {
	store = &RateLimiterMemoryStore{
		visitors:  make(map[string]*Visitor),
		rate:      config.Rate,
		burst:     config.Burst,
		expiresIn: config.ExpiresIn,
		timeNow:   time.Now,
	}
	if config.ExpiresIn == 0 {
		store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn
	}
	if config.Burst == 0 {
		// Round up with a floor of 1: a rate.Limiter with burst 0 allows no events
		// (unless its limit is Inf), so fractional rates like 0.5 req/s would block everything.
		store.burst = int(math.Max(1, math.Ceil(config.Rate)))
	}
	store.lastCleanup = store.timeNow()
	return store
}
// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore.
// Zero values of Burst and ExpiresIn are replaced with defaults by
// NewRateLimiterMemoryStoreWithConfig; Rate has no default.
type RateLimiterMemoryStoreConfig struct {
	Rate      float64       // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
	Burst     int           // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached. 0 means Rate rounded up (minimum 1).
	ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up. 0 means DefaultRateLimiterMemoryStoreConfig.ExpiresIn.
}
// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore.
// Only ExpiresIn (3 minutes) is defaulted here; Rate has no default and a zero Burst is
// derived from Rate by NewRateLimiterMemoryStoreWithConfig.
var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{
	ExpiresIn: 3 * time.Minute,
}
// Allow implements RateLimiterStore.Allow.
//
// Under store.mutex it looks up the Visitor for identifier (creating one with the
// store's rate and burst on first sight), stamps its last-seen time, runs a stale
// visitor cleanup when more than expiresIn has passed since the previous one, and
// finally asks the visitor's limiter for a single event. The error is always nil.
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
	store.mutex.Lock()

	visitor, found := store.visitors[identifier]
	if !found {
		visitor = &Visitor{Limiter: rate.NewLimiter(rate.Limit(store.rate), store.burst)}
		store.visitors[identifier] = visitor
	}

	now := store.timeNow()
	visitor.lastSeen = now
	// Cleanup piggy-backs on regular traffic, at most once per expiresIn window.
	if now.Sub(store.lastCleanup) > store.expiresIn {
		store.cleanupStaleVisitors(now)
	}

	allowed := visitor.AllowN(now, 1)
	store.mutex.Unlock()
	return allowed, nil
}
// cleanupStaleVisitors keeps the visitors map from growing indefinitely: it deletes every
// visitor whose last access lies more than store.expiresIn before now and records now as
// the time of the latest cleanup. It takes no lock itself; Allow calls it while holding
// store.mutex.
func (store *RateLimiterMemoryStore) cleanupStaleVisitors(now time.Time) {
	store.lastCleanup = now
	for key, v := range store.visitors {
		if idle := now.Sub(v.lastSeen); idle > store.expiresIn {
			delete(store.visitors, key)
		}
	}
}
================================================
FILE: middleware/rate_limiter_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"errors"
"math/rand"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
// TestRateLimiter drives the default-configured middleware with a burst of 3: the first
// three requests from one IP pass, later ones get the 429 rate-limit error.
func TestRateLimiter(t *testing.T) {
	e := echo.New()
	okHandler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
	mw := RateLimiterWithConfig(RateLimiterConfig{Store: store})

	const ip = "127.0.0.1"
	const limitErr = "code=429, message=rate limit exceeded"
	expectedErrs := []string{"", "", "", limitErr, limitErr, limitErr, limitErr}

	for i, wantErr := range expectedErrs {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.Header.Add(echo.HeaderXRealIP, ip)
		rec := httptest.NewRecorder()

		err := mw(okHandler)(e.NewContext(req, rec))
		if wantErr == "" {
			assert.NoError(t, err, "request #%d", i)
		} else {
			assert.EqualError(t, err, wantErr, "request #%d", i)
		}
		// Errors are returned rather than rendered, so the recorder keeps its default 200.
		assert.Equal(t, http.StatusOK, rec.Code)
	}
}
func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
assert.Panics(t, func() {
RateLimiterWithConfig(RateLimiterConfig{})
})
assert.NotPanics(t, func() {
RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
})
}
// TestRateLimiterWithConfig uses custom extractor/deny/error handlers that respond
// directly via c.JSON: over-limit requests get 403, requests without an identifier get 400.
func TestRateLimiterWithConfig(t *testing.T) {
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
	e := echo.New()
	okHandler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	// Visitors are identified by the X-Real-IP header; a missing header is an extraction error.
	headerExtractor := func(c *echo.Context) (string, error) {
		if id := c.Request().Header.Get(echo.HeaderXRealIP); id != "" {
			return id, nil
		}
		return "", errors.New("invalid identifier")
	}
	cfg := RateLimiterConfig{
		IdentifierExtractor: headerExtractor,
		DenyHandler: func(c *echo.Context, identifier string, err error) error {
			return c.JSON(http.StatusForbidden, nil)
		},
		ErrorHandler: func(c *echo.Context, err error) error {
			return c.JSON(http.StatusBadRequest, nil)
		},
		Store: store,
	}
	mw, err := cfg.ToMiddleware()
	assert.NoError(t, err)

	// The empty id fails extraction before the store is consulted.
	steps := []struct {
		id   string
		code int
	}{
		{"127.0.0.1", http.StatusOK},
		{"127.0.0.1", http.StatusOK},
		{"127.0.0.1", http.StatusOK},
		{"127.0.0.1", http.StatusForbidden},
		{"", http.StatusBadRequest},
		{"127.0.0.1", http.StatusForbidden},
		{"127.0.0.1", http.StatusForbidden},
	}
	for _, step := range steps {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.Header.Add(echo.HeaderXRealIP, step.id)
		rec := httptest.NewRecorder()

		assert.NoError(t, mw(okHandler)(e.NewContext(req, rec)))
		assert.Equal(t, step.code, rec.Code)
	}
}
// TestRateLimiterWithConfig_defaultDenyHandler checks the default deny and
// error handlers. Both return errors rather than writing a response: 429 once
// the burst is used up, and 403 wrapping the extractor's error when no
// identifier can be extracted.
func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
	e := echo.New()
	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	mw, err := RateLimiterConfig{
		Store: store,
		IdentifierExtractor: func(c *echo.Context) (string, error) {
			if id := c.Request().Header.Get(echo.HeaderXRealIP); id != "" {
				return id, nil
			}
			return "", errors.New("invalid identifier")
		},
	}.ToMiddleware()
	assert.NoError(t, err)

	const (
		tooMany       = "code=429, message=rate limit exceeded"
		extractFailed = "code=403, message=error while extracting identifier, err=invalid identifier"
	)
	testCases := []struct {
		id        string
		expectErr string
	}{
		{id: "127.0.0.1"},
		{id: "127.0.0.1"},
		{id: "127.0.0.1"},
		{id: "127.0.0.1", expectErr: tooMany},
		{id: "", expectErr: extractFailed},
		{id: "127.0.0.1", expectErr: tooMany},
		{id: "127.0.0.1", expectErr: tooMany},
	}
	for _, tc := range testCases {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.Header.Add(echo.HeaderXRealIP, tc.id)
		rec := httptest.NewRecorder()

		err := mw(handler)(e.NewContext(req, rec))
		if tc.expectErr == "" {
			assert.NoError(t, err)
		} else {
			assert.EqualError(t, err, tc.expectErr)
		}
		assert.Equal(t, http.StatusOK, rec.Code)
	}
}
// TestRateLimiterWithConfig_defaultConfig checks that a config with only a
// Store set gets working defaults: the first three requests from the same
// X-Real-IP pass, and later ones fail with a 429 error returned from the chain.
func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
	e := echo.New()
	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	mw, err := RateLimiterConfig{Store: store}.ToMiddleware()
	assert.NoError(t, err)

	const tooMany = "code=429, message=rate limit exceeded"
	// One entry per request; "" means the request must succeed.
	wantErrs := []string{"", "", "", tooMany, tooMany, tooMany, tooMany}
	for _, wantErr := range wantErrs {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
		rec := httptest.NewRecorder()

		err := mw(handler)(e.NewContext(req, rec))
		if wantErr == "" {
			assert.NoError(t, err)
		} else {
			assert.EqualError(t, err, wantErr)
		}
		assert.Equal(t, http.StatusOK, rec.Code)
	}
}
// TestRateLimiterWithConfig_skipper checks that when Skipper returns true the
// handler is called directly and BeforeFunc never runs.
func TestRateLimiterWithConfig_skipper(t *testing.T) {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)

	beforeFuncRan := false
	mw, err := RateLimiterConfig{
		Skipper:    func(c *echo.Context) bool { return true },
		BeforeFunc: func(c *echo.Context) { beforeFuncRan = true },
		Store:      NewRateLimiterMemoryStore(5),
		IdentifierExtractor: func(ctx *echo.Context) (string, error) {
			return "127.0.0.1", nil
		},
	}.ToMiddleware()
	assert.NoError(t, err)

	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	assert.NoError(t, mw(handler)(c))
	assert.False(t, beforeFuncRan)
}
// TestRateLimiterWithConfig_skipperNoSkip checks that when Skipper returns
// false the limiter is applied: BeforeFunc runs, and because a fresh store
// allows the first request, the request reaches the handler without error.
func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
	e := echo.New()
	var beforeFuncRan bool
	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	var inMemoryStore = NewRateLimiterMemoryStore(5)
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	mw, err := RateLimiterConfig{
		Skipper: func(c *echo.Context) bool {
			return false
		},
		BeforeFunc: func(c *echo.Context) {
			beforeFuncRan = true
		},
		Store: inMemoryStore,
		IdentifierExtractor: func(ctx *echo.Context) (string, error) {
			return "127.0.0.1", nil
		},
	}.ToMiddleware()
	assert.NoError(t, err)

	// Check the result instead of discarding it: a non-skipped first request
	// must pass the limiter and reach the handler.
	err = mw(handler)(c)
	assert.NoError(t, err)
	assert.True(t, beforeFuncRan)
	assert.Equal(t, "test", rec.Body.String())
}
// TestRateLimiterWithConfig_beforeFunc checks that BeforeFunc is invoked for a
// request passing through the limiter.
func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
	c := e.NewContext(req, httptest.NewRecorder())

	var beforeRan bool
	mw, err := RateLimiterConfig{
		BeforeFunc: func(c *echo.Context) { beforeRan = true },
		Store:      NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}),
		IdentifierExtractor: func(ctx *echo.Context) (string, error) {
			return "127.0.0.1", nil
		},
	}.ToMiddleware()
	assert.NoError(t, err)

	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	err = mw(handler)(c)
	assert.NoError(t, err)
	assert.True(t, beforeRan)
}
// TestRateLimiterMemoryStore_Allow drives the store (Rate 1, Burst 3) with a
// fake clock that advances 220 ms per request, and checks each allow/deny decision.
func TestRateLimiterMemoryStore_Allow(t *testing.T) {
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second})
	start := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
	const tick = 220 * time.Millisecond

	expectations := []struct {
		id      string
		allowed bool
	}{
		{"127.0.0.1", true},  // 0 ms
		{"127.0.0.1", true},  // 220 ms burst #2
		{"127.0.0.1", true},  // 440 ms burst #3
		{"127.0.0.1", false}, // 660 ms block
		{"127.0.0.1", false}, // 880 ms block
		{"127.0.0.1", true},  // 1100 ms next second #1
		{"127.0.0.2", true},  // 1320 ms allow other ip
		{"127.0.0.1", false}, // 1540 ms no burst
		{"127.0.0.1", false}, // 1760 ms no burst
		{"127.0.0.1", false}, // 1980 ms no burst
		{"127.0.0.1", true},  // 2200 ms no burst
		{"127.0.0.1", false}, // 2420 ms no burst
		{"127.0.0.1", false}, // 2640 ms no burst
		{"127.0.0.1", false}, // 2860 ms no burst
		{"127.0.0.1", true},  // 3080 ms no burst
		{"127.0.0.1", false}, // 3300 ms no burst
		{"127.0.0.1", false}, // 3520 ms no burst
		{"127.0.0.1", false}, // 3740 ms no burst
		{"127.0.0.1", false}, // 3960 ms no burst
		{"127.0.0.1", true},  // 4180 ms no burst
		{"127.0.0.1", false}, // 4400 ms no burst
		{"127.0.0.1", false}, // 4620 ms no burst
		{"127.0.0.1", false}, // 4840 ms no burst
		{"127.0.0.1", true},  // 5060 ms no burst
	}
	for i, want := range expectations {
		offset := time.Duration(i) * tick
		t.Logf("Running testcase #%d => %v", i, offset)
		now := start.Add(offset)
		store.timeNow = func() time.Time { return now }

		allowed, _ := store.Allow(want.id)
		assert.Equal(t, want.allowed, allowed)
	}
}
// TestRateLimiterMemoryStore_cleanupStaleVisitors checks that cleanup keeps
// recently seen visitors (A, B), removes a stale one (C), and keeps D, which was
// stale but was refreshed by an Allow call just before cleanup.
func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
	seenAgo := func(age time.Duration) *Visitor {
		return &Visitor{Limiter: rate.NewLimiter(1, 3), lastSeen: time.Now().Add(-age)}
	}
	store.visitors = map[string]*Visitor{
		"A": seenAgo(0),
		"B": seenAgo(1 * time.Minute),
		"C": seenAgo(5 * time.Minute),
		"D": seenAgo(10 * time.Minute),
	}

	store.Allow("D")
	store.cleanupStaleVisitors(time.Now())

	wantPresent := map[string]bool{"A": true, "B": true, "C": false, "D": true}
	for id, want := range wantPresent {
		_, present := store.visitors[id]
		assert.Equal(t, want, present, "visitor %s", id)
	}
}
func TestNewRateLimiterMemoryStore(t *testing.T) {
testCases := []struct {
rate float64
burst int
expiresIn time.Duration
expectedExpiresIn time.Duration
}{
{1, 3, 5 * time.Second, 5 * time.Second},
{2, 4, 0, 3 * time.Minute},
{1, 5, 10 * time.Minute, 10 * time.Minute},
{3, 7, 0, 3 * time.Minute},
}
for _, tc := range testCases {
store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn})
assert.Equal(t, tc.rate, store.rate)
assert.Equal(t, tc.burst, store.burst)
assert.Equal(t, tc.expectedExpiresIn, store.expiresIn)
}
}
// TestRateLimiterMemoryStore_FractionalRateDefaultBurst checks that a rate below
// one token per second still gets a usable default burst, and that the token
// refills once enough time has passed.
func TestRateLimiterMemoryStore_FractionalRateDefaultBurst(t *testing.T) {
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate: 0.5, // fractional rate should get a burst of at least 1
	})
	now := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
	store.timeNow = func() time.Time { return now }

	allow := func() bool {
		allowed, err := store.Allow("user")
		assert.NoError(t, err)
		return allowed
	}

	assert.True(t, allow(), "first request should not be blocked")
	assert.False(t, allow(), "burst token should be consumed immediately")

	// At 0.5 tokens per second, two seconds refill one token.
	now = now.Add(2 * time.Second)
	assert.True(t, allow(), "token should refill for fractional rate after time passes")
}
// generateAddressList returns count random 15-character strings that the
// benchmarks use as visitor identifiers.
func generateAddressList(count int) []string {
	list := make([]string, 0, count)
	for len(list) < count {
		list = append(list, randomString(15))
	}
	return list
}
// run performs b.N Allow calls against store. Each call uses an identifier
// picked at random from the first n entries of addrs. wg is marked done
// when the calls finish.
func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, n int, b *testing.B) {
	for remaining := b.N; remaining > 0; remaining-- {
		store.Allow(addrs[rand.Intn(n)])
	}
	wg.Done()
}
// benchmarkStore starts `parallel` goroutines, each performing b.N Allow calls
// against store over `size` random identifiers, and waits for all of them.
// The identifier list is generated before the timer is reset, so its cost
// (up to 100k random strings in these benchmarks) is not measured.
func benchmarkStore(store RateLimiterStore, parallel int, size int, b *testing.B) {
	addrs := generateAddressList(size)
	wg := &sync.WaitGroup{}
	b.ResetTimer()
	for i := 0; i < parallel; i++ {
		wg.Add(1)
		go run(wg, store, addrs, size, b)
	}
	wg.Wait()
}
const (
testExpiresIn = 1000 * time.Millisecond
)
func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) {
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
benchmarkStore(store, 10, 1000, b)
}
func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) {
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
benchmarkStore(store, 10, 10000, b)
}
func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) {
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
benchmarkStore(store, 10, 100000, b)
}
func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) {
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
benchmarkStore(store, 100, 10000, b)
}
// TestRateLimiterMemoryStore_TOCTOUFix guards against a time-of-check/time-of-use
// race: Allow must read the clock exactly once per call and use that single
// value for every decision it makes.
func TestRateLimiterMemoryStore_TOCTOUFix(t *testing.T) {
	t.Parallel()
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate:      1,
		Burst:     1,
		ExpiresIn: 2 * time.Second,
	})

	// A frozen clock that counts how often it is read.
	clockReads := 0
	frozen := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
	store.timeNow = func() time.Time {
		clockReads++
		return frozen
	}

	allowed, err := store.Allow("127.0.0.1")
	assert.NoError(t, err)
	assert.True(t, allowed, "First request should be allowed")
	assert.Equal(t, 1, clockReads, "timeNow() should only be called once per Allow()")
}
// TestRateLimiterMemoryStore_ConcurrentAccess hammers one identifier from many
// goroutines and checks that every request gets a decision and that the
// outcome includes both allowed and denied requests.
func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) {
	t.Parallel()
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate:      10,
		Burst:     5,
		ExpiresIn: 5 * time.Second,
	})

	const (
		goroutines           = 50
		requestsPerGoroutine = 20
	)
	var allowedCount, deniedCount atomic.Int32
	var wg sync.WaitGroup
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < requestsPerGoroutine; j++ {
				ok, err := store.Allow("test-user")
				assert.NoError(t, err)
				counter := &deniedCount
				if ok {
					counter = &allowedCount
				}
				counter.Add(1)
				time.Sleep(time.Millisecond)
			}
		}()
	}
	wg.Wait()

	allowed, denied := int(allowedCount.Load()), int(deniedCount.Load())
	assert.Equal(t, goroutines*requestsPerGoroutine, allowed+denied, "All requests should be processed")
	assert.Greater(t, denied, 0, "Some requests should be denied due to rate limiting")
	assert.Greater(t, allowed, 0, "Some requests should be allowed")
}
// TestRateLimiterMemoryStore_RaceDetection spreads heavy concurrent load over a
// few identifiers so the race detector can catch unsynchronized store access.
// Run with: go test -race ./middleware -run TestRateLimiterMemoryStore_RaceDetection
func TestRateLimiterMemoryStore_RaceDetection(t *testing.T) {
	t.Parallel()
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate:      100,
		Burst:     200,
		ExpiresIn: 1 * time.Second,
	})

	const (
		goroutines           = 100
		requestsPerGoroutine = 100
	)
	identifiers := []string{"user1", "user2", "user3", "user4", "user5"}
	var wg sync.WaitGroup
	for routineID := 0; routineID < goroutines; routineID++ {
		// Goroutines are assigned identifiers round-robin, so each identifier is
		// shared by several goroutines.
		identifier := identifiers[routineID%len(identifiers)]
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < requestsPerGoroutine; j++ {
				_, err := store.Allow(identifier)
				assert.NoError(t, err)
			}
		}()
	}
	wg.Wait()
}
// TestRateLimiterMemoryStore_TimeOrdering checks that decisions follow the fake
// clock: the burst is consumed at a fixed instant, and advancing the clock by
// one second makes exactly one more request succeed.
func TestRateLimiterMemoryStore_TimeOrdering(t *testing.T) {
	t.Parallel()
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate:      1,
		Burst:     2,
		ExpiresIn: 5 * time.Second,
	})
	now := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
	store.timeNow = func() time.Time { return now }

	requests := []struct {
		advance time.Duration
		want    bool
		msg     string
	}{
		{0, true, "Request 1 should be allowed (burst)"},
		{0, true, "Request 2 should be allowed (burst)"},
		{0, false, "Request 3 should be denied (burst exhausted)"},
		{time.Second, true, "Request 4 should be allowed (1 token available)"},
	}
	for _, r := range requests {
		now = now.Add(r.advance)
		allowed, _ := store.Allow("user1")
		assert.Equal(t, r.want, allowed, r.msg)
	}
}
================================================
FILE: middleware/recover.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"fmt"
"net/http"
"runtime"
"github.com/labstack/echo/v5"
)
// RecoverConfig defines the config for Recover middleware.
type RecoverConfig struct {
	// Skipper defines a function to skip middleware. When it returns true the
	// next handler is called directly and its panics are not recovered.
	// Optional. Defaults to DefaultRecoverConfig.Skipper.
	Skipper Skipper
	// StackSize is the size in bytes of the buffer the stack trace is captured
	// into; longer traces are truncated to this size.
	// Optional. Default value 4KB (used when the field is zero).
	StackSize int
	// DisableStackAll disables formatting stack traces of all other goroutines
	// into buffer after the trace for the current goroutine.
	// Optional. Default value false.
	DisableStackAll bool
	// DisablePrintStack disables capturing the stack trace. When true, the
	// recovered panic is returned as a plain error instead of being wrapped
	// in a *PanicStackError.
	// Optional. Default value as false.
	DisablePrintStack bool
}
// DefaultRecoverConfig is the default Recover middleware config. It captures
// up to 4 KB of stack trace, including the stacks of all goroutines, and wraps
// the recovered error in a *PanicStackError.
var DefaultRecoverConfig = RecoverConfig{
	StackSize:         4 * 1024, // 4 KB
	Skipper:           DefaultSkipper,
	DisablePrintStack: false,
	DisableStackAll:   false,
}
// Recover returns a middleware that recovers from panics anywhere in the chain
// and returns them as errors, which leaves them to the centralized
// HTTPErrorHandler. It uses DefaultRecoverConfig.
func Recover() echo.MiddlewareFunc {
	config := DefaultRecoverConfig
	return RecoverWithConfig(config)
}

// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration.
//
// The middleware recovers panics raised further down the chain and returns
// them as the handler's error. A panic with http.ErrAbortHandler is re-raised
// unchanged so net/http can abort the response. Non-error panic values are
// formatted with %v. Unless DisablePrintStack is set, the error is wrapped in
// a *PanicStackError holding up to StackSize bytes of stack trace.
func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	// Defaults
	if config.Skipper == nil {
		config.Skipper = DefaultRecoverConfig.Skipper
	}
	// A negative size would make make([]byte, n) panic inside the deferred
	// recover below, which turns a recoverable panic into a crash. Treat any
	// non-positive value as "use the default".
	if config.StackSize <= 0 {
		config.StackSize = DefaultRecoverConfig.StackSize
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) (err error) {
			if config.Skipper(c) {
				return next(c)
			}

			defer func() {
				r := recover()
				if r == nil {
					return
				}
				if r == http.ErrAbortHandler {
					// Sentinel used by handlers to abort the response; must propagate.
					panic(r)
				}
				tmpErr, ok := r.(error)
				if !ok {
					tmpErr = fmt.Errorf("%v", r)
				}
				if !config.DisablePrintStack {
					stack := make([]byte, config.StackSize)
					length := runtime.Stack(stack, !config.DisableStackAll)
					tmpErr = &PanicStackError{Stack: stack[:length], Err: tmpErr}
				}
				// Assigning the named result replaces the handler's return value.
				err = tmpErr
			}()
			return next(c)
		}
	}, nil
}
// PanicStackError is an error type that wraps an error along with its stack trace.
// It is returned when config.DisablePrintStack is set to false.
type PanicStackError struct {
	// Stack is the captured stack trace text.
	Stack []byte
	// Err is the recovered panic value as an error.
	Err error
}

// Error formats the wrapped error followed by the stack trace. A nil Err is
// rendered as "<nil>" instead of panicking, because the type is exported and
// may be constructed without one.
func (e *PanicStackError) Error() string {
	msg := "<nil>"
	if e.Err != nil {
		msg = e.Err.Error()
	}
	return fmt.Sprintf("[PANIC RECOVER] %s %s", msg, e.Stack)
}

// Unwrap returns the wrapped error so errors.Is and errors.As can see it.
func (e *PanicStackError) Unwrap() error {
	return e.Err
}
================================================
FILE: middleware/recover_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bytes"
"errors"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestRecover checks that the default Recover middleware turns a panic into a
// *PanicStackError returned from the chain. It also checks that no response is
// written and nothing is logged.
func TestRecover(t *testing.T) {
	e := echo.New()
	// Send log output to buf so the final assertion can verify that nothing is logged.
	buf := new(bytes.Buffer)
	e.Logger = slog.New(slog.NewTextHandler(buf, nil))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	h := Recover()(func(c *echo.Context) error {
		panic("test")
	})

	err := h(c)
	assert.ErrorContains(t, err, "[PANIC RECOVER] test goroutine")
	var pse *PanicStackError
	if errors.As(err, &pse) {
		assert.Contains(t, string(pse.Stack), "middleware/recover.go")
	} else {
		assert.Fail(t, "not of type PanicStackError")
	}
	assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
	assert.Empty(t, buf.String())            // nothing is logged
}
// TestRecover_skipper checks that when Skipper returns true the panic is not
// recovered: it escapes the handler and nothing is written to the response.
func TestRecover_skipper(t *testing.T) {
	e := echo.New()
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)

	skipAll := func(c *echo.Context) bool { return true }
	h := RecoverWithConfig(RecoverConfig{Skipper: skipAll})(func(c *echo.Context) error {
		panic("testPANIC")
	})

	var err error
	assert.Panics(t, func() { err = h(c) })
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
}
// TestRecoverErrAbortHandler checks that a panic with http.ErrAbortHandler is
// not converted into an error but re-panicked with the same value, and that
// no response is written.
func TestRecoverErrAbortHandler(t *testing.T) {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	h := Recover()(func(c *echo.Context) error {
		panic(http.ErrAbortHandler)
	})

	// h is expected to panic, so nothing placed after the call could run.
	// Capture the re-panicked value in a nested function and assert afterwards.
	var recovered any
	func() {
		defer func() { recovered = recover() }()
		_ = h(c)
	}()

	err, ok := recovered.(error)
	if assert.True(t, ok, "expecting `http.ErrAbortHandler` to be re-panicked, got %v", recovered) {
		assert.ErrorIs(t, err, http.ErrAbortHandler)
	}
	assert.Equal(t, http.StatusOK, rec.Code) // nothing was written to the response
}
// TestRecoverWithConfig checks the error the Recover middleware returns under
// different configurations. The recorder status stays 200 in every case
// because the middleware returns the error instead of writing a response.
func TestRecoverWithConfig(t *testing.T) {
	var testCases = []struct {
		name         string
		givenNoPanic bool
		whenConfig   RecoverConfig
		// expectErrContain is a substring the error message must contain.
		expectErrContain string
		// expectErr is the exact error message expected.
		expectErr string
	}{
		{
			name:             "ok, default config",
			whenConfig:       DefaultRecoverConfig,
			expectErrContain: "[PANIC RECOVER] testPANIC goroutine",
		},
		{
			name:             "ok, no panic",
			givenNoPanic:     true,
			whenConfig:       DefaultRecoverConfig,
			expectErrContain: "",
		},
		{
			name: "ok, DisablePrintStack",
			whenConfig: RecoverConfig{
				DisablePrintStack: true,
			},
			expectErr: "testPANIC",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)
			config := tc.whenConfig
			h := RecoverWithConfig(config)(func(c *echo.Context) error {
				if tc.givenNoPanic {
					return nil
				}
				panic("testPANIC")
			})

			err := h(c)
			switch {
			case tc.expectErrContain != "":
				assert.ErrorContains(t, err, tc.expectErrContain)
			case tc.expectErr != "":
				assert.EqualError(t, err, tc.expectErr)
			default:
				assert.NoError(t, err)
			}
			assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
		})
	}
}
================================================
FILE: middleware/redirect.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v5"
)
// RedirectConfig defines the config for Redirect middleware.
type RedirectConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. Defaults to DefaultSkipper.
	Skipper

	// Code is the HTTP status code used for the redirect response.
	// Optional. Default value http.StatusMovedPermanently.
	Code int

	// redirect decides whether a request must be redirected and builds the
	// target URL. The *Redirect / *RedirectWithConfig constructors set it;
	// ToMiddleware returns an error when it is nil.
	redirect redirectLogic
}
// redirectLogic inspects a request's scheme, host and request URI. It reports
// whether a redirect is needed (ok) and, if so, returns the absolute URL to
// redirect to.
type redirectLogic func(scheme, host, uri string) (ok bool, url string)

// www is the host prefix that the WWW-related redirects add or remove.
const www = "www."

var (
	// RedirectHTTPSConfig is the HTTPS Redirect middleware config.
	RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS}

	// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config.
	RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW}

	// RedirectNonHTTPSWWWConfig is the HTTPS non-WWW Redirect middleware config:
	// it upgrades to https and strips the "www." prefix. It is the preset used
	// by HTTPSNonWWWRedirect.
	RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW}

	// RedirectWWWConfig is the WWW Redirect middleware config.
	RedirectWWWConfig = RedirectConfig{redirect: redirectWWW}

	// RedirectNonWWWConfig is the non WWW Redirect middleware config.
	RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW}
)
// HTTPSRedirect redirects http requests to https.
// For example, http://labstack.com will be redirect to https://labstack.com.
//
// Usage `Echo#Pre(HTTPSRedirect())`
func HTTPSRedirect() echo.MiddlewareFunc {
return HTTPSRedirectWithConfig(RedirectHTTPSConfig)
}
// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration.
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
config.redirect = redirectHTTPS
return toMiddlewareOrPanic(config)
}
// HTTPSWWWRedirect redirects http requests to https www.
// For example, http://labstack.com will be redirect to https://www.labstack.com.
//
// Usage `Echo#Pre(HTTPSWWWRedirect())`
func HTTPSWWWRedirect() echo.MiddlewareFunc {
return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig)
}
// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration.
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
config.redirect = redirectHTTPSWWW
return toMiddlewareOrPanic(config)
}
// HTTPSNonWWWRedirect redirects http requests to https non www.
// For example, http://www.labstack.com will be redirect to https://labstack.com.
//
// Usage `Echo#Pre(HTTPSNonWWWRedirect())`
func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig)
}
// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration.
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
config.redirect = redirectNonHTTPSWWW
return toMiddlewareOrPanic(config)
}
// WWWRedirect redirects non www requests to www.
// For example, http://labstack.com will be redirect to http://www.labstack.com.
//
// Usage `Echo#Pre(WWWRedirect())`
func WWWRedirect() echo.MiddlewareFunc {
return WWWRedirectWithConfig(RedirectWWWConfig)
}
// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration.
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
config.redirect = redirectWWW
return toMiddlewareOrPanic(config)
}
// NonWWWRedirect redirects www requests to non www.
// For example, http://www.labstack.com will be redirect to http://labstack.com.
//
// Usage `Echo#Pre(NonWWWRedirect())`
func NonWWWRedirect() echo.MiddlewareFunc {
return NonWWWRedirectWithConfig(RedirectNonWWWConfig)
}
// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration.
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
config.redirect = redirectNonWWW
return toMiddlewareOrPanic(config)
}
// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration.
//
// For each non-skipped request, the redirect logic receives the request scheme,
// Host and RequestURI. When it asks for a redirect, the middleware responds
// with config.Code and the returned URL; otherwise the next handler runs.
func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.redirect == nil {
		return nil, errors.New("redirectConfig is missing redirect function")
	}
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}
	if config.Code == 0 {
		config.Code = http.StatusMovedPermanently
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}
			req := c.Request()
			needed, target := config.redirect(c.Scheme(), req.Host, req.RequestURI)
			if !needed {
				return next(c)
			}
			return c.Redirect(config.Code, target)
		}
	}, nil
}
// hasWWWPrefix reports whether host begins with "www.", ignoring ASCII case.
// Host names are case-insensitive (RFC 4343) and req.Host is used as received,
// so "WWW.example.com" must count as already having the prefix.
func hasWWWPrefix(host string) bool {
	return len(host) >= len(www) && strings.EqualFold(host[:len(www)], www)
}

// trimWWWPrefix removes a leading "www." (in any letter case) from host.
func trimWWWPrefix(host string) string {
	if hasWWWPrefix(host) {
		return host[len(www):]
	}
	return host
}

// redirectHTTPS redirects any non-https request to the same host and URI over https.
var redirectHTTPS = func(scheme, host, uri string) (bool, string) {
	if scheme != "https" {
		return true, "https://" + host + uri
	}
	return false, ""
}

// redirectHTTPSWWW redirects to https://www.<host> unless the request is
// already https with a www host.
var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) {
	// Redirect if not HTTPS OR missing www prefix (needs either fix).
	if scheme != "https" || !hasWWWPrefix(host) {
		// Strip any existing prefix (whatever its case) to avoid "www.www.".
		return true, "https://www." + trimWWWPrefix(host) + uri
	}
	return false, ""
}

// redirectNonHTTPSWWW redirects to https://<host without www> unless the request
// is already https with a non-www host.
var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) {
	// Redirect if not HTTPS OR has www prefix (needs either fix).
	if scheme != "https" || hasWWWPrefix(host) {
		return true, "https://" + trimWWWPrefix(host) + uri
	}
	return false, ""
}

// redirectWWW adds the www prefix to hosts lacking it, keeping the scheme.
var redirectWWW = func(scheme, host, uri string) (bool, string) {
	if !hasWWWPrefix(host) {
		return true, scheme + "://www." + host + uri
	}
	return false, ""
}

// redirectNonWWW strips the www prefix from hosts that have it, keeping the scheme.
var redirectNonWWW = func(scheme, host, uri string) (bool, string) {
	if hasWWWPrefix(host) {
		return true, scheme + "://" + trimWWWPrefix(host) + uri
	}
	return false, ""
}
================================================
FILE: middleware/redirect_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// middlewareGenerator is a constructor for the middleware under test, such as
// HTTPSRedirect or a closure around NonWWWRedirectWithConfig. The tests pass
// such constructors to redirectTest.
type middlewareGenerator func() echo.MiddlewareFunc
// TestRedirectHTTPSRedirect checks that plain http requests are redirected to
// https, and that requests already forwarded as https are served unchanged.
func TestRedirectHTTPSRedirect(t *testing.T) {
	var testCases = []struct {
		whenHost         string
		whenHeader       http.Header
		expectLocation   string
		expectStatusCode int
	}{
		{
			whenHost:         "labstack.com",
			expectLocation:   "https://labstack.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "labstack.com",
			whenHeader:       map[string][]string{echo.HeaderXForwardedProto: {"https"}},
			expectLocation:   "",
			expectStatusCode: http.StatusOK,
		},
	}
	for _, tc := range testCases {
		// Cases share hosts, so add the forwarded proto to keep subtest names unique.
		name := tc.whenHost
		if proto := tc.whenHeader.Get(echo.HeaderXForwardedProto); proto != "" {
			name += " (X-Forwarded-Proto: " + proto + ")"
		}
		t.Run(name, func(t *testing.T) {
			res := redirectTest(HTTPSRedirect, tc.whenHost, tc.whenHeader)
			assert.Equal(t, tc.expectStatusCode, res.Code)
			assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
		})
	}
}
// TestRedirectHTTPSWWWRedirect checks that requests are redirected to
// https://www.<host> unless they are already https with a www host.
func TestRedirectHTTPSWWWRedirect(t *testing.T) {
	var testCases = []struct {
		whenHost         string
		whenHeader       http.Header
		expectLocation   string
		expectStatusCode int
	}{
		{
			whenHost:         "labstack.com",
			expectLocation:   "https://www.labstack.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "www.labstack.com",
			expectLocation:   "https://www.labstack.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "a.com",
			expectLocation:   "https://www.a.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "ip",
			expectLocation:   "https://www.ip/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "labstack.com",
			whenHeader:       map[string][]string{echo.HeaderXForwardedProto: {"https"}},
			expectLocation:   "https://www.labstack.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "www.labstack.com",
			whenHeader:       map[string][]string{echo.HeaderXForwardedProto: {"https"}},
			expectLocation:   "",
			expectStatusCode: http.StatusOK,
		},
	}
	for _, tc := range testCases {
		// Cases share hosts, so add the forwarded proto to keep subtest names unique.
		name := tc.whenHost
		if proto := tc.whenHeader.Get(echo.HeaderXForwardedProto); proto != "" {
			name += " (X-Forwarded-Proto: " + proto + ")"
		}
		t.Run(name, func(t *testing.T) {
			res := redirectTest(HTTPSWWWRedirect, tc.whenHost, tc.whenHeader)
			assert.Equal(t, tc.expectStatusCode, res.Code)
			assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
		})
	}
}
// TestRedirectHTTPSNonWWWRedirect checks that requests are redirected to
// https://<host without www> unless they are already https with a non-www host.
func TestRedirectHTTPSNonWWWRedirect(t *testing.T) {
	var testCases = []struct {
		whenHost         string
		whenHeader       http.Header
		expectLocation   string
		expectStatusCode int
	}{
		{
			whenHost:         "www.labstack.com",
			expectLocation:   "https://labstack.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "a.com",
			expectLocation:   "https://a.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "ip",
			expectLocation:   "https://ip/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "www.labstack.com",
			whenHeader:       map[string][]string{echo.HeaderXForwardedProto: {"https"}},
			expectLocation:   "https://labstack.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "labstack.com",
			whenHeader:       map[string][]string{echo.HeaderXForwardedProto: {"https"}},
			expectLocation:   "",
			expectStatusCode: http.StatusOK,
		},
	}
	for _, tc := range testCases {
		// Cases share hosts, so add the forwarded proto to keep subtest names unique.
		name := tc.whenHost
		if proto := tc.whenHeader.Get(echo.HeaderXForwardedProto); proto != "" {
			name += " (X-Forwarded-Proto: " + proto + ")"
		}
		t.Run(name, func(t *testing.T) {
			res := redirectTest(HTTPSNonWWWRedirect, tc.whenHost, tc.whenHeader)
			assert.Equal(t, tc.expectStatusCode, res.Code)
			assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
		})
	}
}
// TestRedirectWWWRedirect checks that hosts without the www prefix are
// redirected to the www host, keeping the (possibly forwarded) scheme.
func TestRedirectWWWRedirect(t *testing.T) {
	var testCases = []struct {
		whenHost         string
		whenHeader       http.Header
		expectLocation   string
		expectStatusCode int
	}{
		{
			whenHost:         "labstack.com",
			expectLocation:   "http://www.labstack.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "a.com",
			expectLocation:   "http://www.a.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "ip",
			expectLocation:   "http://www.ip/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "a.com",
			whenHeader:       map[string][]string{echo.HeaderXForwardedProto: {"https"}},
			expectLocation:   "https://www.a.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "www.ip",
			expectLocation:   "",
			expectStatusCode: http.StatusOK,
		},
	}
	for _, tc := range testCases {
		// Cases share hosts, so add the forwarded proto to keep subtest names unique.
		name := tc.whenHost
		if proto := tc.whenHeader.Get(echo.HeaderXForwardedProto); proto != "" {
			name += " (X-Forwarded-Proto: " + proto + ")"
		}
		t.Run(name, func(t *testing.T) {
			res := redirectTest(WWWRedirect, tc.whenHost, tc.whenHeader)
			assert.Equal(t, tc.expectStatusCode, res.Code)
			assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
		})
	}
}
// TestRedirectNonWWWRedirect checks that www hosts are redirected to the bare
// host, keeping the (possibly forwarded) scheme, and that bare hosts pass through.
func TestRedirectNonWWWRedirect(t *testing.T) {
	var testCases = []struct {
		whenHost         string
		whenHeader       http.Header
		expectLocation   string
		expectStatusCode int
	}{
		{
			whenHost:         "www.labstack.com",
			expectLocation:   "http://labstack.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "www.a.com",
			expectLocation:   "http://a.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "www.a.com",
			whenHeader:       map[string][]string{echo.HeaderXForwardedProto: {"https"}},
			expectLocation:   "https://a.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			whenHost:         "ip",
			expectLocation:   "",
			expectStatusCode: http.StatusOK,
		},
	}
	for _, tc := range testCases {
		// Cases share hosts, so add the forwarded proto to keep subtest names unique.
		name := tc.whenHost
		if proto := tc.whenHeader.Get(echo.HeaderXForwardedProto); proto != "" {
			name += " (X-Forwarded-Proto: " + proto + ")"
		}
		t.Run(name, func(t *testing.T) {
			res := redirectTest(NonWWWRedirect, tc.whenHost, tc.whenHeader)
			assert.Equal(t, tc.expectStatusCode, res.Code)
			assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
		})
	}
}
// TestNonWWWRedirectWithConfig exercises NonWWWRedirectWithConfig with the default status code,
// with a Skipper that always skips, and with a custom redirect status code.
//
// Subtests are named after tc.name: every case uses the same host, so naming them by host
// produced indistinguishable subtests ("www.labstack.com", "www.labstack.com#01", ...).
func TestNonWWWRedirectWithConfig(t *testing.T) {
	var testCases = []struct {
		name             string
		givenCode        int
		givenSkipFunc    func(c *echo.Context) bool
		whenHost         string
		whenHeader       http.Header
		expectLocation   string
		expectStatusCode int
	}{
		{
			name:             "usual redirect",
			whenHost:         "www.labstack.com",
			expectLocation:   "http://labstack.com/",
			expectStatusCode: http.StatusMovedPermanently,
		},
		{
			name: "redirect is skipped",
			givenSkipFunc: func(c *echo.Context) bool {
				return true // skip always
			},
			whenHost:         "www.labstack.com",
			expectLocation:   "",
			expectStatusCode: http.StatusOK,
		},
		{
			name:             "redirect with custom status code",
			givenCode:        http.StatusSeeOther,
			whenHost:         "www.labstack.com",
			expectLocation:   "http://labstack.com/",
			expectStatusCode: http.StatusSeeOther,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mwGenerator := func() echo.MiddlewareFunc {
				return NonWWWRedirectWithConfig(RedirectConfig{
					Skipper: tc.givenSkipFunc,
					Code:    tc.givenCode,
				})
			}
			res := redirectTest(mwGenerator, tc.whenHost, tc.whenHeader)
			assert.Equal(t, tc.expectStatusCode, res.Code)
			assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
		})
	}
}
// redirectTest runs the middleware built by fn in front of a handler that always answers 200 OK.
// The request is GET "/" with the given Host; when header is non-nil it replaces the request headers.
// The recorder is returned so callers can inspect the status code and Location header.
func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Host = host
	if header != nil {
		req.Header = header
	}
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	handler := fn()(func(c *echo.Context) error {
		return c.NoContent(http.StatusOK)
	})
	// The chain's error is deliberately discarded; callers assert on the recorder instead.
	_ = handler(c)
	return rec
}
================================================
FILE: middleware/request_id.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"github.com/labstack/echo/v5"
)
// RequestIDConfig defines the config for RequestID middleware.
// All fields are optional; zero values are replaced with defaults by ToMiddleware.
type RequestIDConfig struct {
// Skipper defines a function to skip middleware.
// Optional. Defaults to DefaultSkipper. When it returns true, no ID is read, generated or set.
Skipper Skipper
// Generator defines a function to generate an ID.
// It is only called when the request does not carry a non-empty TargetHeader value.
// Optional. Defaults to a generator of random 32-character strings (createRandomStringGenerator(32)).
Generator func() string
// RequestIDHandler defines a function which is executed for a request id.
// Optional. It receives the ID (from the request header or freshly generated) after the ID has been
// set on the response header and before the next handler in the chain runs.
RequestIDHandler func(c *echo.Context, requestID string)
// TargetHeader defines what header to look for to populate the id.
// The same header name is used when writing the ID to the response.
// Optional. Default value is echo.HeaderXRequestID (`X-Request-Id`).
TargetHeader string
}
// RequestID returns a RequestID middleware using the default configuration.
//
// For every request it reads the `X-Request-Id` request header. When that header is empty, it
// generates a new random ID instead. The resulting ID is written to the response under the same header.
func RequestID() echo.MiddlewareFunc {
	var defaults RequestIDConfig
	return RequestIDWithConfig(defaults)
}
// RequestIDWithConfig builds a RequestID middleware from config and panics when the configuration
// is invalid.
//
// The middleware reads the RequestIDConfig.TargetHeader request header (`X-Request-Id` by default).
// When it is empty, the middleware generates an ID. The ID is then written to the response under
// that same header.
func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration.
//
// Zero-valued fields are filled with defaults:
//   - Skipper: DefaultSkipper
//   - Generator: random 32-character strings
//   - TargetHeader: echo.HeaderXRequestID
//
// The current implementation never returns a non-nil error.
func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	skip := config.Skipper
	if skip == nil {
		skip = DefaultSkipper
	}
	generate := config.Generator
	if generate == nil {
		generate = createRandomStringGenerator(32)
	}
	header := config.TargetHeader
	if header == "" {
		header = echo.HeaderXRequestID
	}
	onID := config.RequestIDHandler

	mw := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if skip(c) {
				return next(c)
			}
			// Prefer an ID supplied by the client; only generate one when it is absent/empty.
			id := c.Request().Header.Get(header)
			if id == "" {
				id = generate()
			}
			c.Response().Header().Set(header, id)
			if onID != nil {
				onID(c, id)
			}
			return next(c)
		}
	}
	return mw, nil
}
================================================
FILE: middleware/request_id_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestRequestID checks that the default middleware generates a 32-character X-Request-Id response header.
func TestRequestID(t *testing.T) {
	e := echo.New()
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)
	h := RequestID()(func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	})
	assert.NoError(t, h(c))
	assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
}
// TestMustRequestIDWithConfig_skipper checks that a skipped request neither calls the generator
// nor sets the request ID response header, while the handler still runs normally.
func TestMustRequestIDWithConfig_skipper(t *testing.T) {
	e := echo.New()
	e.GET("/", func(c *echo.Context) error {
		return c.String(http.StatusTeapot, "test")
	})
	generatorCalled := false
	e.Use(RequestIDWithConfig(RequestIDConfig{
		Skipper: func(c *echo.Context) bool {
			return true
		},
		Generator: func() string {
			generatorCalled = true
			return "customGenerator"
		},
	}))
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	res := httptest.NewRecorder()
	e.ServeHTTP(res, req)
	assert.Equal(t, http.StatusTeapot, res.Code)
	assert.Equal(t, "test", res.Body.String())
	// testify takes (expected, actual); assert.Empty states the intent without reversing them.
	assert.Empty(t, res.Header().Get(echo.HeaderXRequestID))
	assert.False(t, generatorCalled)
}
// TestMustRequestIDWithConfig_customGenerator checks that a custom Generator's value is written
// to the X-Request-Id response header when the request has no ID.
func TestMustRequestIDWithConfig_customGenerator(t *testing.T) {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	rid := RequestIDWithConfig(RequestIDConfig{
		Generator: func() string { return "customGenerator" },
	})
	h := rid(handler)
	err := h(c)
	assert.NoError(t, err)
	// testify takes (expected, actual).
	assert.Equal(t, "customGenerator", rec.Header().Get(echo.HeaderXRequestID))
}
// TestMustRequestIDWithConfig_RequestIDHandler checks two things:
//   - RequestIDHandler is invoked with the same ID that is written to the response header.
//   - The header carries the generated value.
func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) {
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	called := false
	var handledID string
	rid := RequestIDWithConfig(RequestIDConfig{
		Generator: func() string { return "customGenerator" },
		RequestIDHandler: func(c *echo.Context, s string) {
			called = true
			handledID = s
		},
	})
	h := rid(handler)
	err := h(c)
	assert.NoError(t, err)
	// testify takes (expected, actual).
	assert.Equal(t, "customGenerator", rec.Header().Get(echo.HeaderXRequestID))
	assert.True(t, called)
	assert.Equal(t, "customGenerator", handledID)
}
// TestRequestIDWithConfig checks the default configuration obtained via ToMiddleware and a
// configuration with a custom Generator.
//
// Each run uses its own recorder and context so the second assertion cannot observe state left
// behind by the first run. Handler errors are checked instead of silently dropped.
func TestRequestIDWithConfig(t *testing.T) {
	e := echo.New()
	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}

	rid, err := RequestIDConfig{}.ToMiddleware()
	assert.NoError(t, err)
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)
	assert.NoError(t, rid(handler)(c))
	assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)

	// Custom generator
	rid = RequestIDWithConfig(RequestIDConfig{
		Generator: func() string { return "customGenerator" },
	})
	rec = httptest.NewRecorder()
	c = e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)
	assert.NoError(t, rid(handler)(c))
	assert.Equal(t, "customGenerator", rec.Header().Get(echo.HeaderXRequestID))
}
// TestRequestID_IDNotAltered checks that an ID already present in the request header is echoed
// back unchanged instead of being replaced by a generated one.
//
// The given ID must be non-empty. With an empty header value the middleware generates a fresh
// 32-character ID, so the previous `== ""` assertion could never hold.
func TestRequestID_IDNotAltered(t *testing.T) {
	const givenID = "sample-request-id"
	e := echo.New()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Add(echo.HeaderXRequestID, givenID)
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}
	rid := RequestIDWithConfig(RequestIDConfig{})
	h := rid(handler)
	assert.NoError(t, h(c))
	assert.Equal(t, givenID, rec.Header().Get(echo.HeaderXRequestID))
}
// TestRequestIDConfigDifferentHeader checks that TargetHeader redirects both the lookup and the
// response header to X-Correlation-Id. It covers the default generator and a custom generator
// combined with a RequestIDHandler.
//
// Each run uses a fresh recorder/context, and handler errors are asserted rather than ignored.
func TestRequestIDConfigDifferentHeader(t *testing.T) {
	e := echo.New()
	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}

	rid := RequestIDWithConfig(RequestIDConfig{TargetHeader: echo.HeaderXCorrelationID})
	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)
	assert.NoError(t, rid(handler)(c))
	assert.Len(t, rec.Header().Get(echo.HeaderXCorrelationID), 32)
	// The default header must not be populated when a different TargetHeader is configured.
	assert.Empty(t, rec.Header().Get(echo.HeaderXRequestID))

	// Custom generator and handler
	customID := "customGenerator"
	calledHandler := false
	rid = RequestIDWithConfig(RequestIDConfig{
		Generator:    func() string { return customID },
		TargetHeader: echo.HeaderXCorrelationID,
		RequestIDHandler: func(_ *echo.Context, id string) {
			calledHandler = true
			assert.Equal(t, customID, id)
		},
	})
	rec = httptest.NewRecorder()
	c = e.NewContext(httptest.NewRequest(http.MethodGet, "/", nil), rec)
	assert.NoError(t, rid(handler)(c))
	assert.Equal(t, customID, rec.Header().Get(echo.HeaderXCorrelationID))
	assert.True(t, calledHandler)
}
================================================
FILE: middleware/request_logger.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"context"
"errors"
"log/slog"
"net/http"
"time"
"github.com/labstack/echo/v5"
)
// Example for `slog` https://pkg.go.dev/log/slog
// logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogStatus: true,
// LogURI: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST",
// slog.String("uri", v.URI),
// slog.Int("status", v.Status),
// )
// } else {
// logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR",
// slog.String("uri", v.URI),
// slog.Int("status", v.Status),
// slog.String("err", v.Error.Error()),
// )
// }
// return nil
// },
// }))
//
// Example for `fmt.Printf`
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogStatus: true,
// LogURI: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status)
// } else {
// fmt.Printf("REQUEST_ERROR: uri: %v, status: %v, err: %v\n", v.URI, v.Status, v.Error)
// }
// return nil
// },
// }))
//
// Example for Zerolog (https://github.com/rs/zerolog)
// logger := zerolog.New(os.Stdout)
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogURI: true,
// LogStatus: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// logger.Info().
// Str("URI", v.URI).
// Int("status", v.Status).
// Msg("request")
// } else {
// logger.Error().
// Err(v.Error).
// Str("URI", v.URI).
// Int("status", v.Status).
// Msg("request error")
// }
// return nil
// },
// }))
//
// Example for Zap (https://github.com/uber-go/zap)
// logger, _ := zap.NewProduction()
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogURI: true,
// LogStatus: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// logger.Info("request",
// zap.String("URI", v.URI),
// zap.Int("status", v.Status),
// )
// } else {
// logger.Error("request error",
// zap.String("URI", v.URI),
// zap.Int("status", v.Status),
// zap.Error(v.Error),
// )
// }
// return nil
// },
// }))
//
// Example for Logrus (https://github.com/sirupsen/logrus)
// log := logrus.New()
// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
// LogURI: true,
// LogStatus: true,
// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
// if v.Error == nil {
// log.WithFields(logrus.Fields{
// "URI": v.URI,
// "status": v.Status,
// }).Info("request")
// } else {
// log.WithFields(logrus.Fields{
// "URI": v.URI,
// "status": v.Status,
// "error": v.Error,
// }).Error("request error")
// }
// return nil
// },
// }))
// RequestLoggerConfig is configuration for Request Logger middleware.
// StartTime and Error are always filled in RequestLoggerValues. Every other field is only
// populated when its corresponding Log* option below is enabled.
type RequestLoggerConfig struct {
// Skipper defines a function to skip middleware.
// Optional. Defaults to DefaultSkipper. Skipped requests are not logged at all.
Skipper Skipper
// BeforeNextFunc defines a function that is called before next middleware or handler is called in chain.
// Optional. It runs after the start time has been recorded.
BeforeNextFunc func(c *echo.Context)
// LogValuesFunc defines a function that is called with values extracted by logger from request/response.
// Mandatory. ToMiddleware returns an error (and RequestLoggerWithConfig panics) when it is nil.
// An error returned from it is returned by the middleware in place of the handler chain error.
LogValuesFunc func(c *echo.Context, v RequestLoggerValues) error
// HandleError instructs logger to call global error handler when next middleware/handler returns an error.
// This is useful when you have custom error handler that can decide to use different status codes.
//
// A side-effect of calling global error handler is that now Response has been committed and sent to the client
// and middlewares up in chain can not change Response status code or response body.
HandleError bool
// LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call).
// When HandleError is enabled, the measured duration also includes the global error handler call.
LogLatency bool
// LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`)
LogProtocol bool
// LogRemoteIP instructs logger to extract request remote IP. See `echo.Context.RealIP()` for implementation details.
LogRemoteIP bool
// LogHost instructs logger to extract request host value (i.e. `example.com`)
LogHost bool
// LogMethod instructs logger to extract request method value (i.e. `GET` etc)
LogMethod bool
// LogURI instructs logger to extract request URI (i.e. `/list?lang=en&page=1`)
LogURI bool
// LogURIPath instructs logger to extract request URI path part (i.e. `/list`). An empty path is reported as `/`.
LogURIPath bool
// LogRoutePath instructs logger to extract route path part to which request was matched to (i.e. `/user/:id`)
LogRoutePath bool
// LogRequestID instructs logger to extract request ID from request `X-Request-ID` header or response if request did not have value.
// Note: the `X-Request-Id` header name is fixed here and does not follow RequestIDConfig.TargetHeader.
LogRequestID bool
// LogReferer instructs logger to extract the request Referer header value.
LogReferer bool
// LogUserAgent instructs logger to extract the request User-Agent header value.
LogUserAgent bool
// LogStatus instructs logger to extract response status code. If handler chain returns an error,
// the status code is extracted from the error satisfying echo.StatusCoder interface.
LogStatus bool
// LogContentLength instructs logger to extract the raw request Content-Length header value. Note: this value could be
// different from actual request body size as it could be spoofed etc.
LogContentLength bool
// LogResponseSize instructs logger to extract response content length value. Note: when used with Gzip middleware
// this value may not be always correct.
LogResponseSize bool
// LogHeaders instructs logger to extract given list of headers from request. Note: request can contain more than
// one header with the same name, so a slice of values is logged for each given header.
//
// Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
// names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
LogHeaders []string
// LogQueryParams instructs logger to extract given list of query parameters from request URI. Note: request can
// contain more than one query parameter with the same name, so a slice of values is logged for each given query param name.
LogQueryParams []string
// LogFormValues instructs logger to extract given list of form values from request body+URI. Note: request can
// contain more than one form value with the same name, so a slice of values is logged for each given form value name.
//
// Note: the middleware reads request.Form as-is and does not parse the form itself, so values are only available
// when something in the handler chain has already parsed it (e.g. by calling c.FormValue).
LogFormValues []string
// timeNow overrides the clock used for StartTime and Latency; time.Now is used when nil. Intended for tests.
timeNow func() time.Time
}
// RequestLoggerValues contains extracted values from logger.
// Apart from StartTime and Error, a field is only populated when its Log* option is enabled in RequestLoggerConfig.
type RequestLoggerValues struct {
// StartTime is time recorded before next middleware/handler is executed (and before BeforeNextFunc is called).
StartTime time.Time
// Latency is duration it took to execute rest of the handler chain (next(c) call).
Latency time.Duration
// Protocol is request protocol (i.e. `HTTP/1.1` or `HTTP/2`)
Protocol string
// RemoteIP is request remote IP. See `echo.Context.RealIP()` for implementation details.
RemoteIP string
// Host is request host value (i.e. `example.com`)
Host string
// Method is request method value (i.e. `GET` etc)
Method string
// URI is request URI (i.e. `/list?lang=en&page=1`)
URI string
// URIPath is request URI path part (i.e. `/list`)
URIPath string
// RoutePath is route path part to which request was matched to (i.e. `/user/:id`)
RoutePath string
// RequestID is request ID from request `X-Request-ID` header or response if request did not have value.
RequestID string
// Referer is the request Referer header value.
Referer string
// UserAgent is the request User-Agent header value.
UserAgent string
// Status is a response status code. When the handler returns an error satisfying echo.StatusCoder interface, then code from it.
Status int
// Error is error returned from executed handler chain.
Error error
// ContentLength is the raw request Content-Length header value. Note: this value could be different from actual
// request body size as it could be spoofed etc.
ContentLength string
// ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct.
// It is -1 when the underlying echo response could not be resolved.
ResponseSize int64
// Headers are list of headers from request. Note: request can contain more than one header with the same name, so a slice
// of values is what will be returned/logged for each given header.
// Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
// names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
Headers map[string][]string
// QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter
// with the same name, so a slice of values is what will be returned/logged for each given query param name.
QueryParams map[string][]string
// FormValues are list of form values from request body+URI. Note: request can contain more than one form value with
// the same name, so a slice of values is what will be returned/logged for each given form value name.
FormValues map[string][]string
}
// RequestLoggerWithConfig returns a RequestLogger middleware with config.
// It panics when the configuration is invalid (for example when LogValuesFunc is nil).
func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc {
	mw, err := config.ToMiddleware()
	if err == nil {
		return mw
	}
	panic(err)
}
// ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration.
// It returns an error when LogValuesFunc is nil. Header names in LogHeaders are canonicalized once here.
// The header/query/form maps in RequestLoggerValues are only allocated when the matching Log* list is non-empty.
func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultSkipper
}
now := time.Now
if config.timeNow != nil {
now = config.timeNow
}
if config.LogValuesFunc == nil {
return nil, errors.New("missing LogValuesFunc callback function for request logger middleware")
}
logHeaders := len(config.LogHeaders) > 0
// copy before canonicalizing so the caller's LogHeaders slice is not modified
headers := append([]string(nil), config.LogHeaders...)
for i, v := range headers {
headers[i] = http.CanonicalHeaderKey(v)
}
logQueryParams := len(config.LogQueryParams) > 0
logFormValues := len(config.LogFormValues) > 0
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
start := now()
if config.BeforeNextFunc != nil {
config.BeforeNextFunc(c)
}
err := next(c)
if err != nil && config.HandleError {
// When global error handler writes the error to the client the Response gets "committed". This state can be
// checked with `c.Response().Committed` field.
c.Echo().HTTPErrorHandler(c, err)
}
// Values are collected only after next (and the optional error handler) have run, so response-side data
// such as status, size and response headers reflects what was actually written.
res := c.Response()
v := RequestLoggerValues{
StartTime: start,
}
if config.LogLatency {
v.Latency = now().Sub(start)
}
if config.LogProtocol {
v.Protocol = req.Proto
}
if config.LogRemoteIP {
v.RemoteIP = c.RealIP()
}
if config.LogHost {
v.Host = req.Host
}
if config.LogMethod {
v.Method = req.Method
}
if config.LogURI {
v.URI = req.RequestURI
}
if config.LogURIPath {
p := req.URL.Path
if p == "" {
p = "/"
}
v.URIPath = p
}
if config.LogRoutePath {
v.RoutePath = c.Path()
}
if config.LogRequestID {
// The header name is fixed to X-Request-Id. When the request did not carry one, fall back to the response
// header, e.g. the value written by the RequestID middleware.
id := req.Header.Get(echo.HeaderXRequestID)
if id == "" {
id = res.Header().Get(echo.HeaderXRequestID)
}
v.RequestID = id
}
if config.LogReferer {
v.Referer = req.Referer()
}
if config.LogUserAgent {
v.UserAgent = req.UserAgent()
}
if config.LogStatus || config.LogResponseSize {
// status takes the chain error into account (see echo.ResolveResponseStatus)
resp, status := echo.ResolveResponseStatus(res, err)
if config.LogStatus {
v.Status = status
}
if config.LogResponseSize {
// -1 signals that the size could not be determined
v.ResponseSize = -1
if resp != nil {
v.ResponseSize = resp.Size
}
}
}
if err != nil {
v.Error = err
}
if config.LogContentLength {
v.ContentLength = req.Header.Get(echo.HeaderContentLength)
}
if logHeaders {
v.Headers = map[string][]string{}
for _, header := range headers {
if values, ok := req.Header[header]; ok {
v.Headers[header] = values
}
}
}
if logQueryParams {
queryParams := c.QueryParams()
v.QueryParams = map[string][]string{}
for _, param := range config.LogQueryParams {
if values, ok := queryParams[param]; ok {
v.QueryParams[param] = values
}
}
}
if logFormValues {
// req.Form is read as-is (the form is not parsed here); it only holds values if something in the chain
// already parsed it.
v.FormValues = map[string][]string{}
for _, formValue := range config.LogFormValues {
if values, ok := req.Form[formValue]; ok {
v.FormValues[formValue] = values
}
}
}
// an error from LogValuesFunc replaces the handler chain error as the middleware's return value
if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil {
return errOnLog
}
// in case of HandleError=true we are returning the error that we already have handled with global error handler
// this is deliberate as this error could be useful for upstream middlewares and default global error handler
// will ignore that error when it bubbles up in middleware chain.
// Committed response can be checked in custom error handler with following logic
//
// if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed {
// return
// }
return err
}
}, nil
}
// RequestLogger creates Request Logger middleware with Echo default settings that uses Context.Logger() as logger.
//
// Successful requests are logged at slog.LevelInfo with message "REQUEST". Requests whose handler chain
// returned an error are logged at slog.LevelError with message "REQUEST_ERROR" and an extra trailing
// "error" attribute.
//
// HandleError is enabled: the error is forwarded to the global error handler so it can decide the
// status code. NB: as a side-effect the response is now "committed" (written to the client), and
// middlewares up in the chain can not change the status code or body.
func RequestLogger() echo.MiddlewareFunc {
	config := RequestLoggerConfig{
		LogLatency:       true,
		LogRemoteIP:      true,
		LogHost:          true,
		LogMethod:        true,
		LogURI:           true,
		LogRequestID:     true,
		LogUserAgent:     true,
		LogStatus:        true,
		LogContentLength: true,
		LogResponseSize:  true,
		HandleError:      true,
	}
	config.LogValuesFunc = func(c *echo.Context, v RequestLoggerValues) error {
		level, msg := slog.LevelInfo, "REQUEST"
		attrs := []slog.Attr{
			slog.String("method", v.Method),
			slog.String("uri", v.URI),
			slog.Int("status", v.Status),
			slog.Duration("latency", v.Latency),
			slog.String("host", v.Host),
			slog.String("bytes_in", v.ContentLength),
			slog.Int64("bytes_out", v.ResponseSize),
			slog.String("user_agent", v.UserAgent),
			slog.String("remote_ip", v.RemoteIP),
			slog.String("request_id", v.RequestID),
		}
		if v.Error != nil {
			level, msg = slog.LevelError, "REQUEST_ERROR"
			attrs = append(attrs, slog.String("error", v.Error.Error()))
		}
		c.Logger().LogAttrs(context.Background(), level, msg, attrs...)
		return nil
	}
	return RequestLoggerWithConfig(config)
}
================================================
FILE: middleware/request_logger_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"bytes"
"encoding/json"
"errors"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"testing"
"time"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestRequestLoggerOK checks the JSON line written by the default RequestLogger for a successful request.
func TestRequestLoggerOK(t *testing.T) {
	previous := slog.Default()
	t.Cleanup(func() { slog.SetDefault(previous) })

	e := echo.New()
	var out bytes.Buffer
	e.Logger = slog.New(slog.NewJSONHandler(&out, nil))
	e.Use(RequestLogger())
	e.POST("/test", func(c *echo.Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	body := strings.NewReader(`{"foo":"bar"}`)
	req := httptest.NewRequest(http.MethodPost, "/test", body)
	req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(body.Size())))
	req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
	req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
	req.Header.Set("User-Agent", "curl/7.68.0")
	e.ServeHTTP(httptest.NewRecorder(), req)

	got := map[string]any{}
	assert.NoError(t, json.Unmarshal(out.Bytes(), &got))
	// latency and time vary between runs; pin them to placeholders before comparing.
	got["latency"] = 123
	got["time"] = "x"
	assert.Equal(t, map[string]any{
		"level":      "INFO",
		"msg":        "REQUEST",
		"method":     "POST",
		"uri":        "/test",
		"status":     float64(418),
		"bytes_in":   "13",
		"host":       "example.com",
		"bytes_out":  float64(2),
		"user_agent": "curl/7.68.0",
		"remote_ip":  "8.8.8.8",
		"request_id": "",
		"time":       "x",
		"latency":    123,
	}, got)
}
// TestRequestLoggerError checks the JSON line written by the default RequestLogger when the handler
// returns a plain (non-HTTP) error.
func TestRequestLoggerError(t *testing.T) {
	previous := slog.Default()
	t.Cleanup(func() { slog.SetDefault(previous) })

	e := echo.New()
	var out bytes.Buffer
	e.Logger = slog.New(slog.NewJSONHandler(&out, nil))
	e.Use(RequestLogger())
	e.GET("/test", func(c *echo.Context) error {
		return errors.New("nope")
	})

	e.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/test", nil))

	got := map[string]any{}
	assert.NoError(t, json.Unmarshal(out.Bytes(), &got))
	// latency and time vary between runs; pin them to placeholders before comparing.
	got["latency"] = 123
	got["time"] = "x"
	assert.Equal(t, map[string]any{
		"level":      "ERROR",
		"msg":        "REQUEST_ERROR",
		"method":     "GET",
		"uri":        "/test",
		"status":     float64(500),
		"bytes_in":   "",
		"host":       "example.com",
		"bytes_out":  float64(36),
		"user_agent": "",
		"remote_ip":  "192.0.2.1",
		"request_id": "",
		"error":      "nope",
		"latency":    123,
		"time":       "x",
	}, got)
}
func TestRequestLoggerWithConfig(t *testing.T) {
e := echo.New()
var expect RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
LogRoutePath: true,
LogURI: true,
LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
expect = values
return nil
},
}))
e.GET("/test", func(c *echo.Context) error {
return c.String(http.StatusTeapot, "OK")
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTeapot, rec.Code)
assert.Equal(t, "/test", expect.RoutePath)
}
func TestRequestLoggerWithConfig_missingOnLogValuesPanics(t *testing.T) {
assert.Panics(t, func() {
RequestLoggerWithConfig(RequestLoggerConfig{
LogValuesFunc: nil,
})
})
}
// TestRequestLogger_skipper checks that LogValuesFunc is not called for skipped requests while the handler still runs.
func TestRequestLogger_skipper(t *testing.T) {
	logged := false
	skipAll := func(c *echo.Context) bool { return true }

	e := echo.New()
	e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
		Skipper: skipAll,
		LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
			logged = true
			return nil
		},
	}))
	e.GET("/test", func(c *echo.Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	assert.Equal(t, http.StatusTeapot, rec.Code)
	assert.False(t, logged)
}
// TestRequestLogger_beforeNextFunc checks that values stored on the context by BeforeNextFunc are
// visible to LogValuesFunc.
func TestRequestLogger_beforeNextFunc(t *testing.T) {
	const key = "myLoggerInstance"
	var seen int

	e := echo.New()
	e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
		BeforeNextFunc: func(c *echo.Context) {
			c.Set(key, 42)
		},
		LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
			seen = c.Get(key).(int)
			return nil
		},
	}))
	e.GET("/test", func(c *echo.Context) error {
		return c.String(http.StatusTeapot, "OK")
	})

	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))

	assert.Equal(t, http.StatusTeapot, rec.Code)
	assert.Equal(t, 42, seen)
}
func TestRequestLogger_logError(t *testing.T) {
e := echo.New()
var actual RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
LogStatus: true,
LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
actual = values
return nil
},
}))
e.GET("/test", func(c *echo.Context) error {
return echo.NewHTTPError(http.StatusNotAcceptable, "nope")
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotAcceptable, rec.Code)
assert.Equal(t, http.StatusNotAcceptable, actual.Status)
assert.EqualError(t, actual.Error, "code=406, message=nope")
}
func TestRequestLogger_HandleError(t *testing.T) {
e := echo.New()
var actual RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
timeNow: func() time.Time {
return time.Unix(1631045377, 0).UTC()
},
HandleError: true,
LogStatus: true,
LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
actual = values
return nil
},
}))
// to see if "HandleError" works we create custom error handler that uses its own status codes
e.HTTPErrorHandler = func(c *echo.Context, err error) {
if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed {
return
}
c.JSON(http.StatusTeapot, "custom error handler")
}
e.GET("/test", func(c *echo.Context) error {
return echo.NewHTTPError(http.StatusForbidden, "nope")
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTeapot, rec.Code)
expect := RequestLoggerValues{
StartTime: time.Unix(1631045377, 0).UTC(),
Status: http.StatusTeapot,
Error: echo.NewHTTPError(http.StatusForbidden, "nope"),
}
assert.Equal(t, expect, actual)
}
func TestRequestLogger_LogValuesFuncError(t *testing.T) {
e := echo.New()
var expect RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
LogStatus: true,
LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
expect = values
return echo.NewHTTPError(http.StatusNotAcceptable, "LogValuesFuncError")
},
}))
e.GET("/test", func(c *echo.Context) error {
return c.String(http.StatusTeapot, "OK")
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
// NOTE: when global error handler received error returned from middleware the status has already
// been written to the client and response has been "committed" therefore global error handler does not do anything
// and error that bubbled up in middleware chain will not be reflected in response code.
assert.Equal(t, http.StatusTeapot, rec.Code)
assert.Equal(t, http.StatusTeapot, expect.Status)
}
func TestRequestLogger_ID(t *testing.T) {
var testCases = []struct {
name string
whenFromRequest bool
expect string
}{
{
name: "ok, ID is provided from request headers",
whenFromRequest: true,
expect: "123",
},
{
name: "ok, ID is from response headers",
whenFromRequest: false,
expect: "321",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
var expect RequestLoggerValues
e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
LogRequestID: true,
LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
expect = values
return nil
},
}))
e.GET("/test", func(c *echo.Context) error {
c.Response().Header().Set(echo.HeaderXRequestID, "321")
return c.String(http.StatusTeapot, "OK")
})
req := httptest.NewRequest(http.MethodGet, "/test", nil)
if tc.whenFromRequest {
req.Header.Set(echo.HeaderXRequestID, "123")
}
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTeapot, rec.Code)
assert.Equal(t, tc.expect, expect.RequestID)
})
}
}
func TestRequestLogger_headerIsCaseInsensitive(t *testing.T) {
e := echo.New()
var expect RequestLoggerValues
mw := RequestLoggerWithConfig(RequestLoggerConfig{
LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
expect = values
return nil
},
LogHeaders: []string{"referer", "User-Agent"},
})(func(c *echo.Context) error {
c.Request().Header.Set(echo.HeaderXRequestID, "123")
c.FormValue("to force parse form")
return c.String(http.StatusTeapot, "OK")
})
req := httptest.NewRequest(http.MethodGet, "/test?lang=en&checked=1&checked=2", nil)
req.Header.Set("referer", "https://echo.labstack.com/")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := mw(c)
assert.NoError(t, err)
assert.Len(t, expect.Headers, 1)
assert.Equal(t, []string{"https://echo.labstack.com/"}, expect.Headers["Referer"])
}
// TestRequestLogger_allFields enables every LogXXX option and checks that each one populates the
// corresponding RequestLoggerValues field.
func TestRequestLogger_allFields(t *testing.T) {
	e := echo.New()
	isFirstNowCall := true
	var logged RequestLoggerValues
	mw := RequestLoggerWithConfig(RequestLoggerConfig{
		LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
			logged = values
			return nil
		},
		LogLatency:       true,
		LogProtocol:      true,
		LogRemoteIP:      true,
		LogHost:          true,
		LogMethod:        true,
		LogURI:           true,
		LogURIPath:       true,
		LogRoutePath:     true,
		LogRequestID:     true,
		LogReferer:       true,
		LogUserAgent:     true,
		LogStatus:        true,
		LogContentLength: true,
		LogResponseSize:  true,
		LogHeaders:       []string{"accept-encoding", "User-Agent"},
		LogQueryParams:   []string{"lang", "checked"},
		LogFormValues:    []string{"csrf", "multiple"},
		// The first call returns the start time and every later call returns start+10s.
		// This makes the reported Latency deterministic.
		timeNow: func() time.Time {
			if isFirstNowCall {
				isFirstNowCall = false
				return time.Unix(1631045377, 0)
			}
			return time.Unix(1631045377+10, 0)
		},
	})(func(c *echo.Context) error {
		c.Request().Header.Set(echo.HeaderXRequestID, "123")
		c.FormValue("to force parse form")
		return c.String(http.StatusTeapot, "OK")
	})

	f := make(url.Values)
	f.Set("csrf", "token")
	f.Set("multiple", "1")
	f.Add("multiple", "2")
	// Encodes to "csrf=token&multiple=1&multiple=2" (32 bytes), which is asserted as ContentLength below.
	reader := strings.NewReader(f.Encode())
	req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader)
	req.Header.Set("Referer", "https://echo.labstack.com/")
	req.Header.Set("User-Agent", "curl/7.68.0")
	req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size())))
	req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
	req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
	rec := httptest.NewRecorder()
	c := e.NewContext(req, rec)
	c.SetPath("/test*")

	err := mw(c)
	assert.NoError(t, err)

	assert.Equal(t, time.Unix(1631045377, 0), logged.StartTime)
	assert.Equal(t, 10*time.Second, logged.Latency)
	assert.Equal(t, "HTTP/1.1", logged.Protocol)
	assert.Equal(t, "8.8.8.8", logged.RemoteIP)
	assert.Equal(t, "example.com", logged.Host)
	assert.Equal(t, http.MethodPost, logged.Method)
	assert.Equal(t, "/test?lang=en&checked=1&checked=2", logged.URI)
	assert.Equal(t, "/test", logged.URIPath)
	assert.Equal(t, "/test*", logged.RoutePath)
	assert.Equal(t, "123", logged.RequestID)
	assert.Equal(t, "https://echo.labstack.com/", logged.Referer)
	assert.Equal(t, "curl/7.68.0", logged.UserAgent)
	assert.Equal(t, http.StatusTeapot, logged.Status)
	assert.NoError(t, logged.Error)
	assert.Equal(t, "32", logged.ContentLength)
	assert.Equal(t, int64(2), logged.ResponseSize) // len("OK")
	// Accept-Encoding is configured but not sent, so only User-Agent is captured.
	assert.Len(t, logged.Headers, 1)
	assert.Equal(t, []string{"curl/7.68.0"}, logged.Headers["User-Agent"])
	assert.Len(t, logged.QueryParams, 2)
	assert.Equal(t, []string{"en"}, logged.QueryParams["lang"])
	assert.Equal(t, []string{"1", "2"}, logged.QueryParams["checked"])
	assert.Len(t, logged.FormValues, 2)
	assert.Equal(t, []string{"token"}, logged.FormValues["csrf"])
	assert.Equal(t, []string{"1", "2"}, logged.FormValues["multiple"])
}
// TestTestRequestLogger checks the JSON log lines that the default RequestLogger middleware writes through
// e.Logger. It covers both a successful request and a handler that returns an *echo.HTTPError.
func TestTestRequestLogger(t *testing.T) {
	testCases := []struct {
		name         string
		whenStatus   int
		whenError    error
		expectStatus string
		expectError  string
	}{
		{
			name:         "ok",
			whenStatus:   http.StatusTeapot,
			expectStatus: "418",
		},
		{
			name:         "error",
			whenError:    echo.NewHTTPError(http.StatusBadGateway, "bad gw"),
			expectStatus: "502",
			expectError:  `"error":"code=502, message=bad gw"`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			buf := new(bytes.Buffer)
			e := echo.New()
			e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
			e.Use(RequestLogger())
			e.POST("/test", func(c *echo.Context) error {
				if tc.whenError == nil {
					return c.String(tc.whenStatus, "OK")
				}
				return tc.whenError
			})

			form := url.Values{}
			form.Set("csrf", "token")
			form.Set("multiple", "1")
			form.Add("multiple", "2")
			body := strings.NewReader(form.Encode())

			req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", body)
			req.Header.Set("Referer", "https://echo.labstack.com/")
			req.Header.Set("User-Agent", "curl/7.68.0")
			req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(body.Size())))
			req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
			req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
			req.Header.Set(echo.HeaderXRequestID, "MY_ID")
			e.ServeHTTP(httptest.NewRecorder(), req)

			output := buf.String()

			var fragments []string
			if tc.expectError != "" {
				fragments = []string{`"level":"ERROR"`, `"msg":"REQUEST_ERROR"`, tc.expectError}
			} else {
				fragments = []string{`"level":"INFO"`, `"msg":"REQUEST"`}
			}
			fragments = append(fragments,
				`"status":`+tc.expectStatus,
				`"method":"POST"`,
				`"uri":"/test?lang=en&checked=1&checked=2"`,
				`"latency":`, // only the key is checked because the value varies
				`"request_id":"MY_ID"`,
				`"remote_ip":"8.8.8.8"`,
				`"host":"example.com"`,
				`"user_agent":"curl/7.68.0"`,
				`"bytes_in":"32"`,
				`"bytes_out":2`,
			)
			for _, fragment := range fragments {
				assert.Contains(t, output, fragment)
			}
		})
	}
}
// BenchmarkRequestLogger_withoutMapFields measures RequestLogger with every scalar LogXXX option enabled.
// The map-valued options (headers, query params, form values) are left off.
func BenchmarkRequestLogger_withoutMapFields(b *testing.B) {
	e := echo.New()
	mw := RequestLoggerWithConfig(RequestLoggerConfig{
		LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
			return nil
		},
		LogLatency:       true,
		LogProtocol:      true,
		LogRemoteIP:      true,
		LogHost:          true,
		LogMethod:        true,
		LogURI:           true,
		LogURIPath:       true,
		LogRoutePath:     true,
		LogRequestID:     true,
		LogReferer:       true,
		LogUserAgent:     true,
		LogStatus:        true,
		LogContentLength: true,
		LogResponseSize:  true,
	})(func(c *echo.Context) error {
		c.Request().Header.Set(echo.HeaderXRequestID, "123")
		return c.String(http.StatusTeapot, "OK")
	})
	req := httptest.NewRequest(http.MethodGet, "/test?lang=en", nil)
	req.Header.Set("Referer", "https://echo.labstack.com/")
	req.Header.Set("User-Agent", "curl/7.68.0")

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rec := httptest.NewRecorder()
		c := e.NewContext(req, rec)
		// Fail loudly instead of silently benchmarking an error path.
		if err := mw(c); err != nil {
			b.Fatal(err)
		}
	}
}
// BenchmarkRequestLogger_withMapFields measures RequestLogger with every LogXXX option enabled, including
// the map-valued header, query parameter and form value options.
func BenchmarkRequestLogger_withMapFields(b *testing.B) {
	e := echo.New()
	mw := RequestLoggerWithConfig(RequestLoggerConfig{
		LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
			return nil
		},
		LogLatency:       true,
		LogProtocol:      true,
		LogRemoteIP:      true,
		LogHost:          true,
		LogMethod:        true,
		LogURI:           true,
		LogURIPath:       true,
		LogRoutePath:     true,
		LogRequestID:     true,
		LogReferer:       true,
		LogUserAgent:     true,
		LogStatus:        true,
		LogContentLength: true,
		LogResponseSize:  true,
		LogHeaders:       []string{"accept-encoding", "User-Agent"},
		LogQueryParams:   []string{"lang", "checked"},
		LogFormValues:    []string{"csrf", "multiple"},
	})(func(c *echo.Context) error {
		c.Request().Header.Set(echo.HeaderXRequestID, "123")
		c.FormValue("to force parse form")
		return c.String(http.StatusTeapot, "OK")
	})

	f := make(url.Values)
	f.Set("csrf", "token")
	f.Add("multiple", "1")
	f.Add("multiple", "2")
	// NOTE(review): the same *http.Request (and its body reader) is reused for every iteration. After the
	// first iteration the body is drained and net/http may serve the form from the request's cached
	// Form/PostForm. Later iterations therefore probably do not re-parse the form; confirm if form parsing
	// cost matters here.
	req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode()))
	req.Header.Set("Referer", "https://echo.labstack.com/")
	req.Header.Set("User-Agent", "curl/7.68.0")
	req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		rec := httptest.NewRecorder()
		c := e.NewContext(req, rec)
		// Fail loudly instead of silently benchmarking an error path.
		if err := mw(c); err != nil {
			b.Fatal(err)
		}
	}
}
================================================
FILE: middleware/rewrite.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"errors"
"regexp"
"github.com/labstack/echo/v5"
)
// RewriteConfig defines the config for Rewrite middleware.
//
// At least one of Rules or RegexRules must be set; ToMiddleware returns an error when both are nil.
type RewriteConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. Defaults to DefaultSkipper.
	Skipper Skipper

	// Rules defines the URL path rewrite rules. The values captured in asterisk can be
	// retrieved by index e.g. $1, $2 and so on.
	// Example:
	// "/old":              "/new",
	// "/api/*":            "/$1",
	// "/js/*":             "/public/javascripts/$1",
	// "/users/*/orders/*": "/user/$1/order/$2",
	// Required unless RegexRules is set.
	Rules map[string]string

	// RegexRules defines the URL path rewrite rules using regexp.Regexp with capture groups.
	// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
	// Example:
	// "^/old/[0-9]+/":  "/new",
	// "^/api/.+?/(.*)": "/v2/$1",
	// Required unless Rules is set.
	RegexRules map[*regexp.Regexp]string
}
// Rewrite returns a Rewrite middleware that rewrites the URL path using the given wildcard rules.
// All other RewriteConfig fields keep their defaults.
//
// Like RewriteWithConfig, it panics on invalid configuration (for example when rules is nil).
func Rewrite(rules map[string]string) echo.MiddlewareFunc {
	return RewriteWithConfig(RewriteConfig{Rules: rules})
}
// RewriteWithConfig builds a Rewrite middleware from config and panics on invalid configuration.
//
// The resulting middleware rewrites the URL path using config.Rules and config.RegexRules.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
	mw := toMiddlewareOrPanic(config)
	return mw
}
// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration.
//
// At least one of Rules or RegexRules must be non-nil. The wildcard Rules are converted to regular
// expressions and merged with RegexRules into a new map owned by the returned middleware. The caller's
// RegexRules map is never modified, so a RewriteConfig can safely be reused.
func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
	if config.Skipper == nil {
		config.Skipper = DefaultSkipper
	}
	if config.Rules == nil && config.RegexRules == nil {
		return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules")
	}

	// config is a value receiver, but maps are references. Writing converted rules straight into
	// config.RegexRules would mutate the caller's map and pile up duplicate entries each time the same
	// config is turned into middleware. Build a private copy instead.
	regexRules := make(map[*regexp.Regexp]string, len(config.RegexRules)+len(config.Rules))
	for re, replacement := range config.RegexRules {
		regexRules[re] = replacement
	}
	for re, replacement := range rewriteRulesRegex(config.Rules) {
		regexRules[re] = replacement
	}

	skipper := config.Skipper
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c *echo.Context) error {
			if skipper(c) {
				return next(c)
			}
			if err := rewriteURL(regexRules, c.Request()); err != nil {
				return err
			}
			return next(c)
		}
	}, nil
}
================================================
FILE: middleware/rewrite_test.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"testing"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
// TestRewriteAfterRouting checks Rewrite registered with Use(). Routing has already happened, so the rule
// only rewrites req.URL and does not change which handler runs. Already-escaped paths must not be
// double-encoded.
func TestRewriteAfterRouting(t *testing.T) {
	e := echo.New()
	// middlewares added with `Use()` are executed after routing is done and do not affect which route handler is matched
	e.Use(RewriteWithConfig(RewriteConfig{
		Rules: map[string]string{
			"/old":              "/new",
			"/api/*":            "/$1",
			"/js/*":             "/public/javascripts/$1",
			"/users/*/orders/*": "/user/$1/order/$2",
		},
	}))
	e.GET("/public/*", func(c *echo.Context) error {
		return c.String(http.StatusOK, c.Param("*"))
	})
	e.GET("/*", func(c *echo.Context) error {
		return c.String(http.StatusOK, c.Param("*"))
	})

	var testCases = []struct {
		whenPath             string
		expectRoutePath      string
		expectRequestPath    string
		expectRequestRawPath string
	}{
		{
			whenPath:             "/api/users",
			expectRoutePath:      "api/users",
			expectRequestPath:    "/users",
			expectRequestRawPath: "",
		},
		{
			whenPath:             "/js/main.js",
			expectRoutePath:      "js/main.js",
			expectRequestPath:    "/public/javascripts/main.js",
			expectRequestRawPath: "",
		},
		{
			whenPath:             "/users/jack/orders/1",
			expectRoutePath:      "users/jack/orders/1",
			expectRequestPath:    "/user/jack/order/1",
			expectRequestRawPath: "",
		},
		{ // no rewrite rule matched. already encoded URL should not be double encoded or changed in any way
			whenPath:             "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
			expectRoutePath:      "user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
			expectRequestPath:    "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result
			expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
		},
		{ // just rewrite but do not touch encoding. already encoded URL should not be double encoded
			whenPath:             "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
			expectRoutePath:      "users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
			expectRequestPath:    "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result
			expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
		},
		{ // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped or changed in any way when rewriting request
			whenPath:             "/api/new users",
			expectRoutePath:      "api/new users",
			expectRequestPath:    "/new users",
			expectRequestRawPath: "",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.whenPath, func(t *testing.T) {
			// A parse failure would otherwise surface as a nil-pointer panic on target.String().
			target, err := url.Parse(tc.whenPath)
			if !assert.NoError(t, err) {
				return
			}
			req := httptest.NewRequest(http.MethodGet, target.String(), nil)
			rec := httptest.NewRecorder()

			e.ServeHTTP(rec, req)

			assert.Equal(t, http.StatusOK, rec.Code)
			assert.Equal(t, tc.expectRoutePath, rec.Body.String())
			assert.Equal(t, tc.expectRequestPath, req.URL.Path)
			assert.Equal(t, tc.expectRequestRawPath, req.URL.RawPath)
		})
	}
}
func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) {
assert.Panics(t, func() {
RewriteWithConfig(RewriteConfig{})
})
}
// TestMustRewriteWithConfig_skipper checks that the rewrite is applied by default. When the Skipper returns
// true the path is left untouched, so routing to "/new" fails.
func TestMustRewriteWithConfig_skipper(t *testing.T) {
	alwaysSkip := func(c *echo.Context) bool { return true }

	testCases := []struct {
		name         string
		givenSkipper func(c *echo.Context) bool
		whenURL      string
		expectURL    string
		expectStatus int
	}{
		{name: "not skipped", whenURL: "/old", expectURL: "/new", expectStatus: http.StatusOK},
		{name: "skipped", givenSkipper: alwaysSkip, whenURL: "/old", expectURL: "/old", expectStatus: http.StatusNotFound},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			e.Pre(RewriteWithConfig(RewriteConfig{
				Skipper: tc.givenSkipper,
				Rules:   map[string]string{"/old": "/new"},
			}))
			e.GET("/new", func(c *echo.Context) error {
				return c.NoContent(http.StatusOK)
			})

			req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)

			assert.Equal(t, tc.expectURL, req.URL.EscapedPath())
			assert.Equal(t, tc.expectStatus, rec.Code)
		})
	}
}
// TestEchoRewritePreMiddleware is a regression test for issue #1086. Rewrite registered with Pre() runs
// before routing, so the rewritten path ("/new") decides which handler matches.
func TestEchoRewritePreMiddleware(t *testing.T) {
	e := echo.New()
	e.Pre(RewriteWithConfig(RewriteConfig{
		Rules: map[string]string{"/old": "/new"},
	}))
	e.Add(http.MethodGet, "/new", func(c *echo.Context) error {
		return c.NoContent(http.StatusOK)
	})

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/old", nil)
	e.ServeHTTP(rec, req)

	assert.Equal(t, "/new", req.URL.EscapedPath())
	assert.Equal(t, http.StatusOK, rec.Code)
}
// TestRewriteWithConfigPreMiddleware_Issue1143 is a regression test for issue #1143. Overlapping wildcard
// rules must always rewrite "/api/v1/mgmt/proj/test/agt" to the more specific "/api/v1/hosts/test".
func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
	e := echo.New()
	// middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
	e.Pre(RewriteWithConfig(RewriteConfig{
		Rules: map[string]string{
			"/api/*/mgmt/proj/*/agt": "/api/$1/hosts/$2",
			"/api/*/mgmt/proj":       "/api/$1/eng",
		},
	}))
	e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c *echo.Context) error {
		return c.String(http.StatusOK, "hosts")
	})
	e.Add(http.MethodGet, "/api/:version/eng", func(c *echo.Context) error {
		return c.String(http.StatusOK, "eng")
	})

	// Repeated to shake out nondeterministic rule selection. This presumably stems from map iteration
	// order over the rules; confirm against issue #1143.
	for i := 0; i < 100; i++ {
		req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil)
		rec := httptest.NewRecorder()
		e.ServeHTTP(rec, req)

		assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath())
		assert.Equal(t, http.StatusOK, rec.Code)

		// Close the body in each iteration; a defer here would only run when the whole test returns.
		res := rec.Result()
		bodyBytes, err := io.ReadAll(res.Body)
		res.Body.Close()
		assert.NoError(t, err)
		assert.Equal(t, "hosts", string(bodyBytes))
	}
}
// TestEchoRewriteWithCaret is a regression test for issue #1573. A rule starting with "^" must match only
// at the beginning of the path, so neither an already rewritten path nor a path containing "/abc/" later
// on is rewritten.
func TestEchoRewriteWithCaret(t *testing.T) {
	e := echo.New()
	e.Pre(RewriteWithConfig(RewriteConfig{
		Rules: map[string]string{
			"^/abc/*": "/v1/abc/$1",
		},
	}))

	steps := []struct {
		path   string
		expect string
	}{
		{path: "/abc/test", expect: "/v1/abc/test"},
		{path: "/v1/abc/test", expect: "/v1/abc/test"},
		{path: "/v2/abc/test", expect: "/v2/abc/test"},
	}

	rec := httptest.NewRecorder()
	for _, step := range steps {
		req := httptest.NewRequest(http.MethodGet, step.path, nil)
		e.ServeHTTP(rec, req)
		assert.Equal(t, step.expect, req.URL.Path)
	}
}
// TestEchoRewriteWithRegexRules checks that wildcard Rules and RegexRules can be combined in one config.
// Capture groups from both kinds of rule are substituted by index ($1, $2, ...).
func TestEchoRewriteWithRegexRules(t *testing.T) {
	e := echo.New()
	e.Pre(RewriteWithConfig(RewriteConfig{
		Rules: map[string]string{
			"^/a/*":     "/v1/$1",
			"^/b/*/c/*": "/v2/$2/$1",
			"^/c/*/*":   "/v3/$2",
		},
		RegexRules: map[*regexp.Regexp]string{
			regexp.MustCompile("^/x/.+?/(.*)"):   "/v4/$1",
			regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
		},
	}))

	testCases := []struct {
		requestPath string
		expectPath  string
	}{
		{"/unmatched", "/unmatched"},
		{"/a/test", "/v1/test"},
		{"/b/foo/c/bar/baz", "/v2/bar/baz/foo"},
		{"/c/ignore/test", "/v3/test"},
		{"/c/ignore1/test/this", "/v3/test/this"},
		{"/x/ignore/test", "/v4/test"},
		{"/y/foo/bar", "/v5/bar/foo"},
	}

	for _, tc := range testCases {
		t.Run(tc.requestPath, func(t *testing.T) {
			// Request and recorder are local to each subtest rather than shared through outer variables.
			req := httptest.NewRequest(http.MethodGet, tc.requestPath, nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)
			assert.Equal(t, tc.expectPath, req.URL.EscapedPath())
		})
	}
}
// TestEchoRewriteReplacementEscaping is a regression test for issue #1798. Query (?), path-parameter (;)
// and fragment (#) parts written into a replacement must end up in the rewritten URL with the escaping
// defined by the replacement.
func TestEchoRewriteReplacementEscaping(t *testing.T) {
	e := echo.New()
	// NOTE: these are incorrect regexps as they do not factor in that URI we are replacing could contain ? (query) and # (fragment) parts
	// so in reality they append query and fragment part as `$1` matches everything after that prefix
	e.Pre(RewriteWithConfig(RewriteConfig{
		Rules: map[string]string{
			"^/a/*": "/$1?query=param",
			"^/b/*": "/$1;part#one",
		},
		RegexRules: map[*regexp.Regexp]string{
			regexp.MustCompile("^/x/(.*)"): "/$1?query=param",
			regexp.MustCompile("^/y/(.*)"): "/$1;part#one",
			regexp.MustCompile("^/z/(.*)"): "/$1?test=1#escaped%20test",
		},
	}))

	testCases := []struct {
		requestPath string
		expect      string
	}{
		{"/unmatched", "/unmatched"},
		{"/a/test", "/test?query=param"},
		{"/b/foo/bar", "/foo/bar;part#one"},
		{"/x/test", "/test?query=param"},
		{"/y/foo/bar", "/foo/bar;part#one"},
		{"/z/foo/b%20ar", "/foo/b%20ar?test=1#escaped%20test"},
		{"/z/foo/b%20ar?nope=1#yes", "/foo/b%20ar?nope=1#yes?test=1%23escaped%20test"}, // example of appending
	}

	for _, tc := range testCases {
		t.Run(tc.requestPath, func(t *testing.T) {
			// Request and recorder are local to each subtest rather than shared through outer variables.
			req := httptest.NewRequest(http.MethodGet, tc.requestPath, nil)
			rec := httptest.NewRecorder()
			e.ServeHTTP(rec, req)
			assert.Equal(t, tc.expect, req.URL.String())
		})
	}
}
================================================
FILE: middleware/secure.go
================================================
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"fmt"
"github.com/labstack/echo/v5"
)
// SecureConfig defines the config for Secure middleware.
type SecureConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// XSSProtection provides protection against cross-site scripting attack (XSS)
// by setting the `X-XSS-Protection` header.
// Optional. Default value "1; mode=block".
XSSProtection string
// ContentTypeNosniff provides protection against overriding Content-Type
// header by setting the `X-Content-Type-Options` header.
// Optional. Default value "nosniff".
ContentTypeNosniff string
// XFrameOptions can be used to indicate whether or not a browser should
// be allowed to render a page in a ,