Repository: gofiber/fiber
Branch: main
Commit: d6eb4973338e
Files: 382
Total size: 3.6 MB
Directory structure:
gitextract_9m90gl79/
├── .cspell.json
├── .editorconfig
├── .gitattributes
├── .github/
│ ├── .editorconfig
│ ├── .hound.yml
│ ├── CODEOWNERS
│ ├── CODE_OF_CONDUCT.md
│ ├── CONTRIBUTING.md
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug-report.yaml
│ │ ├── config.yml
│ │ ├── feature-request.yaml
│ │ ├── maintenance-task.yaml
│ │ └── question.yaml
│ ├── README.md
│ ├── SECURITY.md
│ ├── codecov.yml
│ ├── config.yml
│ ├── copilot-instructions.md
│ ├── copilot-setup-steps.yml
│ ├── dependabot.yml
│ ├── index.html
│ ├── labeler.yml
│ ├── pull_request_template.md
│ ├── release-drafter.yml
│ ├── release.yml
│ ├── scripts/
│ │ └── sync_docs.sh
│ ├── testdata/
│ │ ├── ca-chain.cert.pem
│ │ ├── fs/
│ │ │ ├── css/
│ │ │ │ ├── style.css
│ │ │ │ └── test/
│ │ │ │ └── style2.css
│ │ │ ├── img/
│ │ │ │ ├── fiberpng
│ │ │ │ └── fiberpng.notvalidext
│ │ │ └── index.html
│ │ ├── hello_world.tmpl
│ │ ├── index.html
│ │ ├── index.tmpl
│ │ ├── main.tmpl
│ │ ├── ssl.key
│ │ ├── ssl.pem
│ │ ├── template-invalid.html
│ │ ├── template.tmpl
│ │ └── testRoutes.json
│ ├── testdata2/
│ │ └── bruh.tmpl
│ ├── testdata3/
│ │ └── hello_world.tmpl
│ └── workflows/
│ ├── auto-labeler.yml
│ ├── benchmark.yml
│ ├── codeql-analysis.yml
│ ├── dependabot_automerge.yml
│ ├── linter.yml
│ ├── manual-dependabot.yml
│ ├── markdown.yml
│ ├── modernize.yml
│ ├── move-closed-milestone-items.yml
│ ├── release-drafter.yml
│ ├── spell-check.yml
│ ├── sync-docs.yml
│ ├── test.yml
│ ├── v3-label-automation.yml
│ └── vulncheck.yml
├── .gitignore
├── .golangci.yml
├── .markdownlint.yml
├── AGENTS.md
├── LICENSE
├── Makefile
├── adapter.go
├── adapter_test.go
├── addon/
│ └── retry/
│ ├── README.md
│ ├── config.go
│ ├── config_test.go
│ ├── exponential_backoff.go
│ └── exponential_backoff_test.go
├── app.go
├── app_integration_test.go
├── app_test.go
├── bind.go
├── bind_test.go
├── binder/
│ ├── README.md
│ ├── binder.go
│ ├── binder_test.go
│ ├── cbor.go
│ ├── cbor_test.go
│ ├── cookie.go
│ ├── cookie_test.go
│ ├── form.go
│ ├── form_test.go
│ ├── header.go
│ ├── header_test.go
│ ├── json.go
│ ├── json_test.go
│ ├── mapping.go
│ ├── mapping_test.go
│ ├── msgpack.go
│ ├── msgpack_test.go
│ ├── query.go
│ ├── query_test.go
│ ├── resp_header.go
│ ├── resp_header_test.go
│ ├── uri.go
│ ├── uri_test.go
│ ├── xml.go
│ └── xml_test.go
├── client/
│ ├── README.md
│ ├── client.go
│ ├── client_test.go
│ ├── cookiejar.go
│ ├── cookiejar_test.go
│ ├── core.go
│ ├── core_test.go
│ ├── errors.go
│ ├── helper_test.go
│ ├── hooks.go
│ ├── hooks_test.go
│ ├── request.go
│ ├── request_bench_test.go
│ ├── request_test.go
│ ├── response.go
│ ├── response_test.go
│ ├── transport.go
│ └── transport_test.go
├── color.go
├── constants.go
├── ctx.go
├── ctx_interface.go
├── ctx_interface_gen.go
├── ctx_test.go
├── docs/
│ ├── addon/
│ │ ├── _category_.json
│ │ └── retry.md
│ ├── api/
│ │ ├── _category_.json
│ │ ├── app.md
│ │ ├── bind.md
│ │ ├── constants.md
│ │ ├── ctx.md
│ │ ├── fiber.md
│ │ ├── hooks.md
│ │ ├── log.md
│ │ ├── redirect.md
│ │ ├── services.md
│ │ └── state.md
│ ├── client/
│ │ ├── _category_.json
│ │ ├── examples.md
│ │ ├── hooks.md
│ │ ├── request.md
│ │ ├── response.md
│ │ └── rest.md
│ ├── extra/
│ │ ├── _category_.json
│ │ ├── benchmarks.md
│ │ ├── faq.md
│ │ ├── internal.md
│ │ └── learning-resources.md
│ ├── guide/
│ │ ├── _category_.json
│ │ ├── advance-format.md
│ │ ├── context.md
│ │ ├── error-handling.md
│ │ ├── extractors.md
│ │ ├── faster-fiber.md
│ │ ├── grouping.md
│ │ ├── reverse-proxy.md
│ │ ├── routing.md
│ │ ├── templates.md
│ │ ├── utils.md
│ │ └── validation.md
│ ├── intro.md
│ ├── middleware/
│ │ ├── _category_.json
│ │ ├── adaptor.md
│ │ ├── basicauth.md
│ │ ├── cache.md
│ │ ├── compress.md
│ │ ├── cors.md
│ │ ├── csrf.md
│ │ ├── earlydata.md
│ │ ├── encryptcookie.md
│ │ ├── envvar.md
│ │ ├── etag.md
│ │ ├── expvar.md
│ │ ├── favicon.md
│ │ ├── healthcheck.md
│ │ ├── helmet.md
│ │ ├── idempotency.md
│ │ ├── keyauth.md
│ │ ├── limiter.md
│ │ ├── logger.md
│ │ ├── paginate.md
│ │ ├── pprof.md
│ │ ├── proxy.md
│ │ ├── recover.md
│ │ ├── redirect.md
│ │ ├── requestid.md
│ │ ├── responsetime.md
│ │ ├── rewrite.md
│ │ ├── session.md
│ │ ├── skip.md
│ │ ├── static.md
│ │ └── timeout.md
│ ├── partials/
│ │ └── routing/
│ │ └── handler.md
│ └── whats_new.md
├── error.go
├── error_test.go
├── errors_internal.go
├── extractors/
│ ├── README.md
│ ├── extractors.go
│ └── extractors_test.go
├── go.mod
├── go.sum
├── group.go
├── helpers.go
├── helpers_fuzz_test.go
├── helpers_test.go
├── hooks.go
├── hooks_test.go
├── internal/
│ ├── memory/
│ │ ├── memory.go
│ │ └── memory_test.go
│ ├── storage/
│ │ └── memory/
│ │ ├── config.go
│ │ ├── memory.go
│ │ └── memory_test.go
│ └── tlstest/
│ └── tls.go
├── listen.go
├── listen_test.go
├── log/
│ ├── default.go
│ ├── default_test.go
│ ├── fiberlog.go
│ ├── fiberlog_test.go
│ └── log.go
├── middleware/
│ ├── adaptor/
│ │ ├── adaptor.go
│ │ └── adaptor_test.go
│ ├── basicauth/
│ │ ├── basicauth.go
│ │ ├── basicauth_test.go
│ │ └── config.go
│ ├── cache/
│ │ ├── cache.go
│ │ ├── cache_test.go
│ │ ├── config.go
│ │ ├── heap.go
│ │ ├── manager.go
│ │ ├── manager_msgp.go
│ │ ├── manager_msgp_test.go
│ │ └── manager_test.go
│ ├── compress/
│ │ ├── compress.go
│ │ ├── compress_test.go
│ │ └── config.go
│ ├── cors/
│ │ ├── config.go
│ │ ├── cors.go
│ │ ├── cors_test.go
│ │ ├── utils.go
│ │ └── utils_test.go
│ ├── csrf/
│ │ ├── config.go
│ │ ├── config_test.go
│ │ ├── csrf.go
│ │ ├── csrf_test.go
│ │ ├── helpers.go
│ │ ├── helpers_test.go
│ │ ├── session_manager.go
│ │ ├── storage_manager.go
│ │ ├── storage_manager_msgp.go
│ │ ├── storage_manager_msgp_test.go
│ │ └── token.go
│ ├── earlydata/
│ │ ├── config.go
│ │ ├── earlydata.go
│ │ └── earlydata_test.go
│ ├── encryptcookie/
│ │ ├── config.go
│ │ ├── config_test.go
│ │ ├── encryptcookie.go
│ │ ├── encryptcookie_test.go
│ │ └── utils.go
│ ├── envvar/
│ │ ├── config.go
│ │ ├── envvar.go
│ │ └── envvar_test.go
│ ├── etag/
│ │ ├── config.go
│ │ ├── etag.go
│ │ └── etag_test.go
│ ├── expvar/
│ │ ├── config.go
│ │ ├── expvar.go
│ │ └── expvar_test.go
│ ├── favicon/
│ │ ├── config.go
│ │ ├── favicon.go
│ │ └── favicon_test.go
│ ├── healthcheck/
│ │ ├── config.go
│ │ ├── healthcheck.go
│ │ └── healthcheck_test.go
│ ├── helmet/
│ │ ├── config.go
│ │ ├── helmet.go
│ │ └── helmet_test.go
│ ├── idempotency/
│ │ ├── config.go
│ │ ├── idempotency.go
│ │ ├── idempotency_test.go
│ │ ├── locker.go
│ │ ├── locker_test.go
│ │ ├── response.go
│ │ ├── response_msgp.go
│ │ ├── response_msgp_test.go
│ │ └── stub_test.go
│ ├── keyauth/
│ │ ├── config.go
│ │ ├── config_test.go
│ │ ├── keyauth.go
│ │ └── keyauth_test.go
│ ├── limiter/
│ │ ├── config.go
│ │ ├── limiter.go
│ │ ├── limiter_fixed.go
│ │ ├── limiter_sliding.go
│ │ ├── limiter_test.go
│ │ ├── manager.go
│ │ ├── manager_msgp.go
│ │ └── manager_msgp_test.go
│ ├── logger/
│ │ ├── config.go
│ │ ├── data.go
│ │ ├── default_logger.go
│ │ ├── errors.go
│ │ ├── format.go
│ │ ├── logger.go
│ │ ├── logger_test.go
│ │ ├── tags.go
│ │ ├── template_chain.go
│ │ └── utils.go
│ ├── paginate/
│ │ ├── config.go
│ │ ├── page_info.go
│ │ ├── paginate.go
│ │ └── paginate_test.go
│ ├── pprof/
│ │ ├── config.go
│ │ ├── pprof.go
│ │ └── pprof_test.go
│ ├── proxy/
│ │ ├── config.go
│ │ ├── proxy.go
│ │ └── proxy_test.go
│ ├── recover/
│ │ ├── config.go
│ │ ├── recover.go
│ │ └── recover_test.go
│ ├── redirect/
│ │ ├── config.go
│ │ ├── redirect.go
│ │ └── redirect_test.go
│ ├── requestid/
│ │ ├── config.go
│ │ ├── requestid.go
│ │ └── requestid_test.go
│ ├── responsetime/
│ │ ├── config.go
│ │ ├── responsetime.go
│ │ └── responsetime_test.go
│ ├── rewrite/
│ │ ├── config.go
│ │ ├── rewrite.go
│ │ └── rewrite_test.go
│ ├── session/
│ │ ├── config.go
│ │ ├── config_test.go
│ │ ├── data.go
│ │ ├── data_msgp.go
│ │ ├── data_msgp_test.go
│ │ ├── data_test.go
│ │ ├── middleware.go
│ │ ├── middleware_test.go
│ │ ├── session.go
│ │ ├── session_test.go
│ │ ├── store.go
│ │ └── store_test.go
│ ├── skip/
│ │ ├── skip.go
│ │ └── skip_test.go
│ ├── static/
│ │ ├── config.go
│ │ ├── static.go
│ │ └── static_test.go
│ └── timeout/
│ ├── config.go
│ ├── timeout.go
│ └── timeout_test.go
├── mount.go
├── mount_test.go
├── path.go
├── path_test.go
├── path_testcases_test.go
├── prefork.go
├── prefork_test.go
├── readonly.go
├── readonly_strict.go
├── redirect.go
├── redirect_msgp.go
├── redirect_msgp_test.go
├── redirect_test.go
├── register.go
├── req.go
├── req_interface_gen.go
├── res.go
├── res_interface_gen.go
├── router.go
├── router_test.go
├── services.go
├── services_test.go
├── state.go
├── state_test.go
└── storage_interface.go
================================================
FILE CONTENTS
================================================
================================================
FILE: .cspell.json
================================================
{
"version": "0.2",
"language": "en, en-gb, en-us",
"useGitignore": true,
"caseSensitive": false,
"import": [
"@cspell/dict-en_us/cspell-ext.json",
"@cspell/dict-en-gb/cspell-ext.json",
"@cspell/dict-software-terms/cspell-ext.json",
"@cspell/dict-golang/cspell-ext.json",
"@cspell/dict-fullstack/cspell-ext.json",
"@cspell/dict-docker/cspell-ext.json",
"@cspell/dict-k8s/cspell-ext.json",
"@cspell/dict-node/cspell-ext.json",
"@cspell/dict-npm/cspell-ext.json",
"@cspell/dict-typescript/cspell-ext.json",
"@cspell/dict-html/cspell-ext.json",
"@cspell/dict-css/cspell-ext.json",
"@cspell/dict-shell/cspell-ext.json",
"@cspell/dict-python/cspell-ext.json",
"@cspell/dict-redis/cspell-ext.json",
"@cspell/dict-sql/cspell-ext.json",
"@cspell/dict-filetypes/cspell-ext.json",
"@cspell/dict-companies/cspell-ext.json",
"@cspell/dict-markdown/cspell-ext.json",
"@cspell/dict-en-common-misspellings/cspell-ext.json",
"@cspell/dict-people-names/cspell-ext.json"
],
"dictionaries": [
"en_us",
"en-gb",
"softwareTerms",
"web-services",
"networking-terms",
"software-term-suggestions",
"software-services",
"software-terms",
"software-tools",
"coding-compound-terms",
"golang",
"fullstack",
"docker",
"k8s",
"node",
"npm",
"typescript",
"html",
"css",
"shell",
"python",
"redis",
"sql",
"filetypes",
"companies",
"markdown",
"en-common-misspellings",
"people-names",
"data-science",
"data-science-models",
"data-science-tools"
],
"ignorePaths": [
".git",
"node_modules",
"vendor",
"internal",
".github",
"**/*.svg",
"**/*.png",
"**/*.jpg",
"**/*.jpeg",
"**/*.gif",
"**/*.ico",
"**/*.lock",
"**/*_gen.go",
"**/*_msgp.go",
"**/*_msgp_test.go",
"go.mod",
"go.sum",
".golangci.yml",
".markdownlint.yml",
"AGENTS.md"
]
}
================================================
FILE: .editorconfig
================================================
; This file is for unifying the coding style for different editors and IDEs.
; More information at http://editorconfig.org
; This style originates from https://github.com/fewagency/best-practices
root = true
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
[*.go]
indent_style = tab
indent_size = 4
tab_width = 4
[Makefile]
indent_style = tab
[*.{yml,yaml,json,md}]
indent_style = space
indent_size = 2
================================================
FILE: .gitattributes
================================================
# Handle line endings automatically for files detected as text
# and leave all files detected as binary untouched.
* text=auto eol=lf
# Force batch scripts to always use CRLF line endings so that if a repo is accessed
# in Windows via a file share from Linux, the scripts will work.
*.{cmd,[cC][mM][dD]} text eol=crlf
*.{bat,[bB][aA][tT]} text eol=crlf
# Force bash scripts to always use LF line endings so that if a repo is accessed
# in Unix via a file share from Windows, the scripts will work.
*.sh text eol=lf
================================================
FILE: .github/.editorconfig
================================================
; https://editorconfig.org/
root = true
[*]
insert_final_newline = true
charset = utf-8
trim_trailing_whitespace = true
indent_style = space
indent_size = 2
[{Makefile,go.mod,go.sum,*.go,.gitmodules}]
indent_style = tab
indent_size = 8
[*.md]
indent_size = 4
trim_trailing_whitespace = false
eclint_indent_style = unset
[Dockerfile]
indent_size = 4
================================================
FILE: .github/.hound.yml
================================================
golint:
enabled: false
================================================
FILE: .github/CODEOWNERS
================================================
* @gofiber/maintainers
================================================
FILE: .github/CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our community include:
- Demonstrating empathy and kindness toward other people
- Being respectful of differing opinions, viewpoints, and experiences
- Giving and gracefully accepting constructive feedback
- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
- Focusing on what is best not just for us as individuals, but for the overall community
Examples of unacceptable behavior include:
- The use of sexualized language or imagery, and sexual attention or advances of any kind
- Trolling, insulting or derogatory comments, and personal or political attacks
- Public or private harassment
- Publishing others' private information, such as a physical or email address, without their explicit permission
- Other conduct which could reasonably be considered inappropriate in a professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [Discord](https://gofiber.io/discord). All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of actions.
**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0,
available at [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html](https://www.contributor-covenant.org/version/2/0/code_of_conduct.html).
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq). Translations are available at [https://www.contributor-covenant.org/translations](https://www.contributor-covenant.org/translations).
================================================
FILE: .github/CONTRIBUTING.md
================================================
# Contributing
Before making any changes to this repository, we kindly request you to initiate discussions for proposed changes that do not yet have an associated [issue](https://github.com/gofiber/fiber/issues). Please use our [Discord](https://gofiber.io/discord) server to initiate these discussions. For [issues](https://github.com/gofiber/fiber/issues) that already exist, you may proceed with discussions using our [issue](https://github.com/gofiber/fiber/issues) tracker or any other suitable method, in consultation with the repository owners. Your collaboration is greatly appreciated.
Please note: we have a [code of conduct](https://github.com/gofiber/fiber/blob/main/.github/CODE_OF_CONDUCT.md), please follow it in all your interactions with the `Fiber` project.
## Pull Requests or Commits
Titles must always start with one of the following prefixes:
> 🔥 Feature, ♻️ Refactor, 🩹 Fix, 🚨 Test, 📚 Doc, 🎨 Style
- 🔥 Feature: Add flow to add person
- ♻️ Refactor: Rename file X to Y
- 🩹 Fix: Improve flow
- 🚨 Test: Validate to add a new person
- 📚 Doc: Translate to Portuguese middleware redirect
- 🎨 Style: Respected pattern Golint
All pull requests that contain a feature or fix must include unit tests. Your PR will only be merged if it follows this flow.
## 👍 Contribute
If you want to say **thank you** and/or support the active development of `Fiber`:
1. Add a [GitHub Star](https://github.com/gofiber/fiber/stargazers) to the project.
2. Tweet about the project [on your 𝕏 (Twitter)](https://x.com/intent/tweet?text=%F0%9F%9A%80%20Fiber%20%E2%80%94%20is%20an%20Express.js%20inspired%20web%20framework%20built%20on%20Fasthttp%20for%20%23Go%20https%3A%2F%2Fgithub.com%2Fgofiber%2Ffiber).
3. Write a review or tutorial on [Medium](https://medium.com/), [Dev.to](https://dev.to/) or personal blog.
4. Support the project by donating a [cup of coffee](https://buymeacoff.ee/fenny).
================================================
FILE: .github/FUNDING.yml
================================================
# These are supported funding model platforms
github: [gofiber] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
custom: https://github.com/sponsors/gofiber
================================================
FILE: .github/ISSUE_TEMPLATE/bug-report.yaml
================================================
name: "\U0001F41B Bug Report"
title: "\U0001F41B [Bug]: "
description: Create a bug report to help us fix it.
labels: ["☢️ Bug"]
body:
- type: markdown
id: notice
attributes:
value: |
### Notice
**This repository is not related to external or third-party Fiber modules. If you have a problem with them, open an issue in their repositories. If you think the problem is related to Fiber, open the issue here.**
- Don't forget you can ask your questions in our [Discord server](https://gofiber.io/discord).
- If you have a suggestion for a Fiber feature you would like to see, open the issue with the **📝 Feature Proposal** template.
- Write your issue in clear and understandable English.
- type: textarea
id: description
attributes:
label: "Bug Description"
description: "A clear and detailed description of what the bug is."
placeholder: "Explain your problem clearly and in detail."
validations:
required: true
- type: textarea
id: how-to-reproduce
attributes:
label: How to Reproduce
description: "Steps to reproduce the behavior and what should be observed in the end."
placeholder: "Tell us step by step how we can replicate your problem and what we should see in the end."
value: |
Steps to reproduce the behavior:
1. Go to '....'
2. Click on '....'
3. Do '....'
4. See '....'
validations:
required: true
- type: textarea
id: expected-behavior
attributes:
label: Expected Behavior
description: "A clear and detailed description of what you think should happen."
placeholder: "Tell us what Fiber should normally do."
validations:
required: true
- type: input
id: version
attributes:
label: "Fiber Version"
description: "Some bugs may be fixed in future Fiber releases, so we have to know your Fiber version."
placeholder: "Write your Fiber version. (v2.33.0, v2.34.1...)"
validations:
required: true
- type: textarea
id: snippet
attributes:
label: "Code Snippet (optional)"
description: "For some issues, we need to know some parts of your code."
placeholder: "Share a code snippet that you think is related to the issue."
render: go
value: |
package main
import "github.com/gofiber/fiber/v3"
import "log"
func main() {
app := fiber.New()
// Steps to reproduce
log.Fatal(app.Listen(":3000"))
}
- type: checkboxes
id: terms
attributes:
label: "Checklist:"
description: "By submitting this issue, you confirm that:"
options:
- label: "I agree to follow Fiber's [Code of Conduct](https://github.com/gofiber/fiber/blob/main/.github/CODE_OF_CONDUCT.md)."
required: true
- label: "I have checked for existing issues that describe my problem prior to opening this one."
required: true
- label: "I understand that improperly formatted bug reports may be closed without explanation."
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/config.yml
================================================
blank_issues_enabled: false
================================================
FILE: .github/ISSUE_TEMPLATE/feature-request.yaml
================================================
name: "📝 Feature Proposal"
title: "📝 [Proposal]: "
description: Propose a feature or improvement for Fiber.
labels: ["📝 Proposal", "✏️ Feature", "v3"]
body:
- type: markdown
id: notice
attributes:
value: |
### Notice
- For questions, join our [Discord server](https://gofiber.io/discord).
- Please write in clear, understandable English.
- Ensure your proposal aligns with Express design principles and HTTP RFC standards.
- Describe features expected to remain stable and not require changes in the foreseeable future.
- type: textarea
id: description
attributes:
label: "Feature Proposal Description"
description: "A clear and detailed description of the feature you are proposing for Fiber v3. How should it work, and what API endpoints and methods would it involve?"
placeholder: "Describe your feature proposal clearly and in detail, including API endpoints and methods."
validations:
required: true
- type: textarea
id: express-alignment
attributes:
label: "Alignment with Express API"
description: "Explain how your proposal aligns with the design and API of Express.js. Provide comparative examples if possible."
placeholder: "Outline how the feature aligns with Express.js design principles and API standards."
validations:
required: true
- type: textarea
id: standards-compliance
attributes:
label: "HTTP RFC Standards Compliance"
description: "Confirm that the feature complies with HTTP RFC standards, and describe any relevant aspects."
placeholder: "Detail how the feature adheres to HTTP RFC standards."
validations:
required: true
- type: textarea
id: stability
attributes:
label: "API Stability"
description: "Discuss the expected stability of the feature and its API. How do you ensure that it will not require changes or deprecations in the near future?"
placeholder: "Describe measures taken to ensure the feature's API stability over time."
validations:
required: true
- type: textarea
id: examples
attributes:
label: "Feature Examples"
description: "Provide concrete examples and code snippets to illustrate how the proposed feature should function."
placeholder: "Share code snippets that exemplify the proposed feature and its usage."
render: go
validations:
required: true
- type: checkboxes
id: terms
attributes:
label: "Checklist:"
description: "By submitting this issue, you confirm that:"
options:
- label: "I agree to follow Fiber's [Code of Conduct](https://github.com/gofiber/fiber/blob/main/.github/CODE_OF_CONDUCT.md)."
required: true
- label: "I have searched for existing issues that describe my proposal before opening this one."
required: true
- label: "I understand that a proposal that does not meet these guidelines may be closed without explanation."
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/maintenance-task.yaml
================================================
name: "🧹 Maintenance Task"
title: "🧹 [Maintenance]: "
description: Describe a maintenance task for Fiber v3.
labels: ["🧹 Updates", "v3"]
body:
- type: markdown
id: notice
attributes:
value: |
### Notice
- Before submitting a maintenance task, please check if a similar task has already been filed.
- Clearly outline the purpose of the maintenance task and its impact on the project.
- Use clear and understandable English.
- type: textarea
id: task-description
attributes:
label: "Maintenance Task Description"
description: "Provide a detailed description of the maintenance task. Include any specific areas of the codebase that require attention, and the desired outcomes of this task."
placeholder: "Detail the maintenance task, specifying what needs to be done and why it is necessary."
validations:
required: true
- type: textarea
id: impact
attributes:
label: "Impact on the Project"
description: "Explain the impact this maintenance will have on the project. Include benefits and potential risks if applicable."
placeholder: "Describe how completing this task will benefit the project, or the risks of not addressing it."
validations:
required: false
- type: textarea
id: additional-context
attributes:
label: "Additional Context (optional)"
description: "Any additional information or context regarding the maintenance task that might be helpful."
placeholder: "Provide any additional information that may be relevant to the task at hand."
validations:
required: false
- type: checkboxes
id: terms
attributes:
label: "Checklist:"
description: "Please confirm the following:"
options:
- label: "I have confirmed that this maintenance task is currently not being addressed."
required: true
- label: "I understand that this task will be evaluated by the maintainers and prioritized accordingly."
required: true
- label: "I am available to provide further information if needed."
required: true
================================================
FILE: .github/ISSUE_TEMPLATE/question.yaml
================================================
name: "🤔 Question"
title: "\U0001F914 [Question]: "
description: Ask a question so we can help you easily.
labels: ["🤔 Question"]
body:
- type: markdown
id: notice
attributes:
value: |
### Notice
- Don't forget you can ask your questions in our [Discord server](https://gofiber.io/discord).
- If you think this is a bug, open the issue with the **🐛 Bug Report** template.
- If you have a suggestion for a Fiber feature you would like to see, open the issue with the **📝 Feature Proposal** template.
- Write your issue in clear and understandable English.
- type: textarea
id: description
attributes:
label: "Question Description"
description: "A clear and detailed description of the question."
placeholder: "Explain your question clearly and in detail."
validations:
required: true
- type: textarea
id: snippet
attributes:
label: "Code Snippet (optional)"
description: "Code snippet may be really helpful to describe some features."
placeholder: "Share a code snippet to explain the feature better."
render: go
value: |
package main
import "github.com/gofiber/fiber/v3"
import "log"
func main() {
app := fiber.New()
// An example to describe the question
log.Fatal(app.Listen(":3000"))
}
- type: checkboxes
id: terms
attributes:
label: "Checklist:"
description: "By submitting this issue, you confirm that:"
options:
- label: "I agree to follow Fiber's [Code of Conduct](https://github.com/gofiber/fiber/blob/main/.github/CODE_OF_CONDUCT.md)."
required: true
- label: "I have checked for existing issues that describe my questions prior to opening this one."
required: true
- label: "I understand that improperly formatted questions may be closed without explanation."
required: true
================================================
FILE: .github/README.md
================================================
Fiber is an Express inspired web framework built on top of Fasthttp, the fastest HTTP engine for Go. Designed to ease things up for fast development with zero memory allocation and performance in mind.
---
## ⚙️ Installation
Fiber requires **Go version `1.25` or higher** to run. If you need to install or upgrade Go, visit the [official Go download page](https://go.dev/dl/). To start setting up your project, create a new directory for your project and navigate into it. Then, initialize your project with Go modules by executing the following command in your terminal:
```bash
go mod init github.com/your/repo
```
To learn more about Go modules and how they work, you can check out the [Using Go Modules](https://go.dev/blog/using-go-modules) blog post.
After setting up your project, you can install Fiber with the `go get` command:
```bash
go get -u github.com/gofiber/fiber/v3
```
This command fetches the Fiber package and adds it to your project's dependencies, allowing you to start building your web applications with Fiber.
## ⚡️ Quickstart
Getting started with Fiber is easy. Here's a basic example to create a simple web server that responds with "Hello, World 👋!" on the root path. This example demonstrates initializing a new Fiber app, setting up a route, and starting the server.
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
// Initialize a new Fiber app
app := fiber.New()
// Define a route for the GET method on the root path '/'
app.Get("/", func(c fiber.Ctx) error {
// Send a string response to the client
return c.SendString("Hello, World 👋!")
})
// Start the server on port 3000
log.Fatal(app.Listen(":3000"))
}
```
This simple server is easy to set up and run. It introduces the core concepts of Fiber: app initialization, route definition, and starting the server. Just run this Go program, and visit `http://localhost:3000` in your browser to see the message.
## Zero Allocation
Fiber is optimized for **high performance**, meaning values returned from **fiber.Ctx** are **not** immutable by default and **will** be re-used across requests. As a rule of thumb, you **must** only use context values within the handler and **must not** keep any references. Once you return from the handler, any values obtained from the context will be re-used in future requests. Visit our [documentation](https://docs.gofiber.io/#zero-allocation) to learn more.
## 🤖 Benchmarks
These tests are performed by [TechEmpower](https://www.techempower.com/benchmarks/#section=data-r19&hw=ph&test=plaintext). If you want to see all the results, please visit our [benchmarks page](https://docs.gofiber.io/extra/benchmarks).
## 🎯 Features
- Robust [Routing](https://docs.gofiber.io/guide/routing)
- Serve [Static Files](https://docs.gofiber.io/api/app#static)
- Extreme [Performance](https://docs.gofiber.io/extra/benchmarks)
- [Low Memory](https://docs.gofiber.io/extra/benchmarks) footprint
- [API Endpoints](https://docs.gofiber.io/api/ctx)
- [Middleware](https://docs.gofiber.io/category/-middleware) & [Next](https://docs.gofiber.io/api/ctx#next) support
- [Rapid](https://dev.to/koddr/welcome-to-fiber-an-express-js-styled-fastest-web-framework-written-with-on-golang-497) server-side programming
- [Template Engines](https://github.com/gofiber/template)
- [WebSocket Support](https://github.com/gofiber/contrib/tree/main/websocket)
- [Socket.io Support](https://github.com/gofiber/contrib/tree/main/socketio)
- [Server-Sent Events](https://github.com/gofiber/recipes/tree/master/sse)
- [Rate Limiter](https://docs.gofiber.io/api/middleware/limiter)
- And much more, [explore Fiber](https://docs.gofiber.io/)
## 💡 Philosophy
New gophers who make the switch from [Node.js](https://nodejs.org/en/about/) to [Go](https://go.dev/doc/) face a learning curve before they can start building their web applications or microservices. Fiber, as a **web framework**, was created with the idea of **minimalism** and follows the **UNIX way**, so that new gophers can quickly enter the world of Go with a warm and trusted welcome.
Fiber is **inspired** by Express, the most popular web framework on the Internet. We combined the **ease** of Express and **raw performance** of Go. If you have ever implemented a web application in Node.js (_using Express or similar_), then many methods and principles will seem **very familiar** to you.
We **listen** to our users in [issues](https://github.com/gofiber/fiber/issues), Discord [channel](https://gofiber.io/discord) _and all over the Internet_ to create a **fast**, **flexible** and **friendly** Go web framework for **any** task, **deadline** and developer **skill**! Just like Express does in the JavaScript world.
## ⚠️ Limitations
- Due to Fiber's usage of `unsafe`, the library may not always be compatible with the latest Go version. Fiber v3 has been tested with Go version 1.25 or higher.
- Fiber automatically adapts common `net/http` handler shapes when you register them on the router, and you can still use the [adaptor middleware](https://docs.gofiber.io/next/middleware/adaptor/) when you need to bridge entire apps or `net/http` middleware.
### net/http compatibility
Fiber can run side by side with the standard library. The router accepts existing `net/http` handlers directly and even works with native `fasthttp.RequestHandler` callbacks, so you can plug in legacy endpoints without wrapping them manually:
```go
package main
import (
"log"
"net/http"
"github.com/gofiber/fiber/v3"
)
func main() {
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := w.Write([]byte("served by net/http")); err != nil {
panic(err)
}
})
app := fiber.New()
app.Get("/", httpHandler)
// Start the server on port 3000
log.Fatal(app.Listen(":3000"))
}
```
When you need to convert entire applications or re-use `net/http` middleware chains, rely on the [adaptor middleware](https://docs.gofiber.io/next/middleware/adaptor/). It converts handlers and middlewares in both directions and even lets you mount a Fiber app in a `net/http` server.
### Express-style handlers
Fiber also adapts Express-style callbacks that operate on the lightweight `fiber.Req` and `fiber.Res` helper interfaces. This lets you port middleware and route handlers from Express-inspired codebases while keeping Fiber's router features:
```go
// Request/response handlers (2-argument)
app.Get("/", func(req fiber.Req, res fiber.Res) error {
return res.SendString("Hello from Express-style handlers!")
})
// Middleware with an error-returning next callback (3-argument)
app.Use(func(req fiber.Req, res fiber.Res, next func() error) error {
if req.IP() == "192.168.1.254" {
return res.SendStatus(fiber.StatusForbidden)
}
return next()
})
// Middleware with a no-arg next callback (3-argument)
app.Use(func(req fiber.Req, res fiber.Res, next func()) {
if req.Get("X-Skip") == "true" {
return // stop the chain without calling next
}
next()
})
```
> **Note:** Adapted `net/http` handlers continue to operate with the standard-library semantics. They don't get access to `fiber.Ctx` features and incur the overhead of the compatibility layer, so native `fiber.Handler` callbacks still provide the best performance.
## 👀 Examples
Listed below are some of the common examples. If you want to see more code examples, please visit our [Recipes repository](https://github.com/gofiber/recipes) or visit our hosted [API documentation](https://docs.gofiber.io).
### 📖 [**Basic Routing**](https://docs.gofiber.io/#basic-routing)
```go title="Example"
package main
import (
"fmt"
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// GET /api/register
app.Get("/api/*", func(c fiber.Ctx) error {
msg := fmt.Sprintf("✋ %s", c.Params("*"))
return c.SendString(msg) // => ✋ register
})
// GET /flights/LAX-SFO
app.Get("/flights/:from-:to", func(c fiber.Ctx) error {
msg := fmt.Sprintf("💸 From: %s, To: %s", c.Params("from"), c.Params("to"))
return c.SendString(msg) // => 💸 From: LAX, To: SFO
})
// GET /dictionary.txt
app.Get("/:file.:ext", func(c fiber.Ctx) error {
msg := fmt.Sprintf("📃 %s.%s", c.Params("file"), c.Params("ext"))
return c.SendString(msg) // => 📃 dictionary.txt
})
// GET /john/75
app.Get("/:name/:age/:gender?", func(c fiber.Ctx) error {
msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age"))
return c.SendString(msg) // => 👴 john is 75 years old
})
// GET /john
app.Get("/:name", func(c fiber.Ctx) error {
msg := fmt.Sprintf("Hello, %s 👋!", c.Params("name"))
return c.SendString(msg) // => Hello john 👋!
})
log.Fatal(app.Listen(":3000"))
}
```
#### 📖 [**Route Naming**](https://docs.gofiber.io/api/app#name)
```go title="Example"
package main
import (
"encoding/json"
"fmt"
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/api/*", func(c fiber.Ctx) error {
msg := fmt.Sprintf("✋ %s", c.Params("*"))
return c.SendString(msg) // => ✋ register
}).Name("api")
route := app.GetRoute("api")
data, _ := json.MarshalIndent(route, "", " ")
fmt.Println(string(data))
// Prints:
// {
// "method": "GET",
// "name": "api",
// "path": "/api/*",
// "params": [
// "*1"
// ]
// }
log.Fatal(app.Listen(":3000"))
}
```
#### 📖 [**Serving Static Files**](https://docs.gofiber.io/api/app#static)
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/static"
)
func main() {
app := fiber.New()
// Serve static files from the "./public" directory
app.Get("/*", static.New("./public"))
// => http://localhost:3000/js/script.js
// => http://localhost:3000/css/style.css
app.Get("/prefix*", static.New("./public"))
// => http://localhost:3000/prefix/js/script.js
// => http://localhost:3000/prefix/css/style.css
// Serve a single file for any unmatched routes
app.Get("*", static.New("./public/index.html"))
// => http://localhost:3000/any/path/shows/index.html
log.Fatal(app.Listen(":3000"))
}
```
#### 📖 [**Middleware & Next**](https://docs.gofiber.io/api/ctx#next)
```go title="Example"
package main
import (
"fmt"
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// Middleware that matches any route
app.Use(func(c fiber.Ctx) error {
fmt.Println("🥇 First handler")
return c.Next()
})
// Middleware that matches all routes starting with /api
app.Use("/api", func(c fiber.Ctx) error {
fmt.Println("🥈 Second handler")
return c.Next()
})
// GET /api/list
app.Get("/api/list", func(c fiber.Ctx) error {
fmt.Println("🥉 Last handler")
return c.SendString("Hello, World 👋!")
})
log.Fatal(app.Listen(":3000"))
}
```
📚 Show more code examples
### Views Engines
📖 [Config](https://docs.gofiber.io/api/fiber#config)
📖 [Engines](https://github.com/gofiber/template)
📖 [Render](https://docs.gofiber.io/api/ctx#render)
Fiber defaults to [html/template](https://pkg.go.dev/html/template/) when no view engine is set.
If you want to execute partials or use a different engine such as [amber](https://github.com/eknkc/amber), [handlebars](https://github.com/aymerick/raymond), [mustache](https://github.com/cbroglie/mustache), or [pug](https://github.com/Joker/jade), check out our [Template](https://github.com/gofiber/template) package, which supports multiple view engines.
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/template/pug"
)
func main() {
// Initialize a new Fiber app with Pug template engine
app := fiber.New(fiber.Config{
Views: pug.New("./views", ".pug"),
})
// Define a route that renders the "home.pug" template
app.Get("/", func(c fiber.Ctx) error {
return c.Render("home", fiber.Map{
"title": "Homepage",
"year": 1999,
})
})
log.Fatal(app.Listen(":3000"))
}
```
### Grouping Routes into Chains
📖 [Group](https://docs.gofiber.io/api/app#group)
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func middleware(c fiber.Ctx) error {
log.Println("Middleware executed")
return c.Next()
}
func handler(c fiber.Ctx) error {
return c.SendString("Handler response")
}
func main() {
app := fiber.New()
// Root API group with middleware
api := app.Group("/api", middleware) // /api
// API v1 routes
v1 := api.Group("/v1", middleware) // /api/v1
v1.Get("/list", handler) // /api/v1/list
v1.Get("/user", handler) // /api/v1/user
// API v2 routes
v2 := api.Group("/v2", middleware) // /api/v2
v2.Get("/list", handler) // /api/v2/list
v2.Get("/user", handler) // /api/v2/user
log.Fatal(app.Listen(":3000"))
}
```
### Middleware Logger
📖 [Logger](https://docs.gofiber.io/api/middleware/logger)
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/logger"
)
func main() {
app := fiber.New()
// Use Logger middleware
app.Use(logger.New())
// Define routes
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, Logger!")
})
log.Fatal(app.Listen(":3000"))
}
```
### Cross-Origin Resource Sharing (CORS)
📖 [CORS](https://docs.gofiber.io/api/middleware/cors)
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/cors"
)
func main() {
app := fiber.New()
// Use CORS middleware with default settings
app.Use(cors.New())
// Define routes
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("CORS enabled!")
})
log.Fatal(app.Listen(":3000"))
}
```
Check CORS by passing any domain in the `Origin` header:
```bash
curl -H "Origin: http://example.com" --verbose http://localhost:3000
```
### Custom 404 Response
📖 [Status](https://docs.gofiber.io/api/ctx#status)
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// Define routes
app.Get("/", static.New("./public"))
app.Get("/demo", func(c fiber.Ctx) error {
return c.SendString("This is a demo page!")
})
app.Post("/register", func(c fiber.Ctx) error {
return c.SendString("Registration successful!")
})
// Middleware to handle 404 Not Found
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusNotFound) // => 404 "Not Found"
})
log.Fatal(app.Listen(":3000"))
}
```
### JSON Response
📖 [JSON](https://docs.gofiber.io/api/ctx#json)
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
type User struct {
Name string `json:"name"`
Age int `json:"age"`
}
func main() {
app := fiber.New()
// Route that returns a JSON object
app.Get("/user", func(c fiber.Ctx) error {
return c.JSON(&User{"John", 20})
// => {"name":"John", "age":20}
})
// Route that returns a JSON map
app.Get("/json", func(c fiber.Ctx) error {
return c.JSON(fiber.Map{
"success": true,
"message": "Hi John!",
})
// => {"success":true, "message":"Hi John!"}
})
log.Fatal(app.Listen(":3000"))
}
```
### WebSocket Upgrade
📖 [Websocket](https://github.com/gofiber/contrib/tree/main/websocket)
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/websocket"
)
func main() {
app := fiber.New()
// WebSocket route
app.Get("/ws", websocket.New(func(c *websocket.Conn) {
defer c.Close()
for {
// Read message from client
mt, msg, err := c.ReadMessage()
if err != nil {
log.Println("read:", err)
break
}
log.Printf("recv: %s", msg)
// Write message back to client
err = c.WriteMessage(mt, msg)
if err != nil {
log.Println("write:", err)
break
}
}
}))
log.Fatal(app.Listen(":3000"))
// Connect via WebSocket at ws://localhost:3000/ws
}
```
### Server-Sent Events
📖 [More Info](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)
```go title="Example"
package main
import (
"bufio"
"fmt"
"log"
"time"
"github.com/gofiber/fiber/v3"
"github.com/valyala/fasthttp"
)
func main() {
app := fiber.New()
// Server-Sent Events route
app.Get("/sse", func(c fiber.Ctx) error {
c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
c.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
var i int
for {
i++
msg := fmt.Sprintf("%d - the time is %v", i, time.Now())
fmt.Fprintf(w, "data: Message: %s\n\n", msg)
fmt.Println(msg)
w.Flush()
time.Sleep(5 * time.Second)
}
})
return nil
})
log.Fatal(app.Listen(":3000"))
}
```
### Recover Middleware
📖 [Recover](https://docs.gofiber.io/api/middleware/recover)
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/recover"
)
func main() {
app := fiber.New()
// Use Recover middleware to handle panics gracefully
app.Use(recover.New())
// Route that intentionally panics
app.Get("/", func(c fiber.Ctx) error {
panic("normally this would crash your app")
})
log.Fatal(app.Listen(":3000"))
}
```
### Using Trusted Proxy
📖 [Config](https://docs.gofiber.io/api/fiber#config)
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New(fiber.Config{
// Configure trusted proxies - WARNING: Only trust proxies you control
// Using TrustProxy: true with unrestricted IPs can lead to IP spoofing
TrustProxy: true,
TrustProxyConfig: fiber.TrustProxyConfig{
Proxies: []string{"10.0.0.0/8", "172.16.0.0/12"}, // Example: Internal network ranges only
},
ProxyHeader: fiber.HeaderXForwardedFor,
})
// Define routes
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Trusted Proxy Configured!")
})
log.Fatal(app.Listen(":3000"))
}
```
## 🧬 Internal Middleware
Here is a list of middleware that are included within the Fiber framework.
| Middleware | Description |
|--------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [adaptor](https://github.com/gofiber/fiber/tree/main/middleware/adaptor) | Converter for net/http handlers to/from Fiber request handlers. |
| [basicauth](https://github.com/gofiber/fiber/tree/main/middleware/basicauth) | Provides HTTP basic authentication. It calls the next handler for valid credentials and returns 401 Unauthorized for missing or invalid credentials. |
| [cache](https://github.com/gofiber/fiber/tree/main/middleware/cache) | Intercept and cache HTTP responses. |
| [compress](https://github.com/gofiber/fiber/tree/main/middleware/compress) | Compression middleware for Fiber, with support for `deflate`, `gzip`, `brotli` and `zstd`. |
| [cors](https://github.com/gofiber/fiber/tree/main/middleware/cors) | Enable cross-origin resource sharing (CORS) with various options. |
| [csrf](https://github.com/gofiber/fiber/tree/main/middleware/csrf) | Protect from CSRF exploits. |
| [earlydata](https://github.com/gofiber/fiber/tree/main/middleware/earlydata) | Adds support for TLS 1.3's early data ("0-RTT") feature. |
| [encryptcookie](https://github.com/gofiber/fiber/tree/main/middleware/encryptcookie) | Encrypt middleware which encrypts cookie values. |
| [envvar](https://github.com/gofiber/fiber/tree/main/middleware/envvar) | Exposes environment variables, with an optional config. |
| [etag](https://github.com/gofiber/fiber/tree/main/middleware/etag) | Allows for caches to be more efficient and save bandwidth, as a web server does not need to resend a full response if the content has not changed. |
| [expvar](https://github.com/gofiber/fiber/tree/main/middleware/expvar) | Serves runtime-exposed variables in JSON format via its HTTP server. |
| [favicon](https://github.com/gofiber/fiber/tree/main/middleware/favicon) | Ignore favicon from logs or serve from memory if a file path is provided. |
| [healthcheck](https://github.com/gofiber/fiber/tree/main/middleware/healthcheck) | Liveness and Readiness probes for Fiber. |
| [helmet](https://github.com/gofiber/fiber/tree/main/middleware/helmet) | Helps secure your apps by setting various HTTP headers. |
| [idempotency](https://github.com/gofiber/fiber/tree/main/middleware/idempotency) | Allows for fault-tolerant APIs where duplicate requests do not erroneously cause the same action to be performed multiple times on the server side. |
| [keyauth](https://github.com/gofiber/fiber/tree/main/middleware/keyauth) | Adds support for key based authentication. |
| [limiter](https://github.com/gofiber/fiber/tree/main/middleware/limiter) | Adds Rate-limiting support to Fiber. Use to limit repeated requests to public APIs and/or endpoints such as password reset. |
| [logger](https://github.com/gofiber/fiber/tree/main/middleware/logger) | HTTP request/response logger. |
| [paginate](https://github.com/gofiber/fiber/tree/main/middleware/paginate) | Extracts pagination parameters from query strings. Supports page-based, offset-based, and cursor-based pagination with multi-field sorting. |
| [pprof](https://github.com/gofiber/fiber/tree/main/middleware/pprof) | Serves runtime profiling data in pprof format. |
| [proxy](https://github.com/gofiber/fiber/tree/main/middleware/proxy) | Allows you to proxy requests to multiple servers. |
| [recover](https://github.com/gofiber/fiber/tree/main/middleware/recover) | Recovers from panics anywhere in the stack chain and hands control to the centralized ErrorHandler. |
| [redirect](https://github.com/gofiber/fiber/tree/main/middleware/redirect) | Redirect middleware. |
| [requestid](https://github.com/gofiber/fiber/tree/main/middleware/requestid) | Adds a request ID to every request. |
| [responsetime](https://github.com/gofiber/fiber/tree/main/middleware/responsetime) | Measures request handling duration and writes it to a configurable response header. |
| [rewrite](https://github.com/gofiber/fiber/tree/main/middleware/rewrite) | Rewrites the URL path based on provided rules. It can be helpful for backward compatibility or just creating cleaner and more descriptive links. |
| [session](https://github.com/gofiber/fiber/tree/main/middleware/session) | Session middleware. NOTE: This middleware uses our Storage package. |
| [skip](https://github.com/gofiber/fiber/tree/main/middleware/skip) | Skip middleware that skips a wrapped handler if a predicate is true. |
| [static](https://github.com/gofiber/fiber/tree/main/middleware/static) | Static middleware for Fiber that serves static files such as **images**, **CSS**, and **JavaScript**. |
| [timeout](https://github.com/gofiber/fiber/tree/main/middleware/timeout) | Adds a max time for a request and forwards to ErrorHandler if it is exceeded. |
## 🧬 External Middleware
List of externally hosted middleware modules maintained by the [Fiber team](https://github.com/orgs/gofiber/people).
| Middleware | Description |
| :------------------------------------------------ | :-------------------------------------------------------------------------------------------------------------------- |
| [contrib](https://github.com/gofiber/contrib) | Third-party middlewares |
| [storage](https://github.com/gofiber/storage) | Premade storage drivers that implement the Storage interface, designed to be used with various Fiber middlewares. |
| [template](https://github.com/gofiber/template) | This package contains 9 template engines that can be used with Fiber. |
## 🕶️ Awesome List
For more articles, middlewares, examples, or tools, check our [awesome list](https://github.com/gofiber/awesome-fiber).
## 👍 Contribute
If you want to say **Thank You** and/or support the active development of `Fiber`:
1. Add a [GitHub Star](https://github.com/gofiber/fiber/stargazers) to the project.
2. Tweet about the project [on your 𝕏 (Twitter)](https://x.com/intent/tweet?text=Fiber%20is%20an%20Express%20inspired%20%23web%20%23framework%20built%20on%20top%20of%20Fasthttp%2C%20the%20fastest%20HTTP%20engine%20for%20%23Go.%20Designed%20to%20ease%20things%20up%20for%20%23fast%20development%20with%20zero%20memory%20allocation%20and%20%23performance%20in%20mind%20%F0%9F%9A%80%20https%3A%2F%2Fgithub.com%2Fgofiber%2Ffiber).
3. Write a review or tutorial on [Medium](https://medium.com/), [Dev.to](https://dev.to/) or your personal blog.
4. Support the project by donating a [cup of coffee](https://buymeacoff.ee/fenny).
## 💻 Development
To ensure your contributions are ready for a Pull Request, please use the following `Makefile` commands. These tools help maintain code quality and consistency.
- **make help**: Display available commands.
- **make audit**: Conduct quality checks.
- **make benchmark**: Benchmark code performance.
- **make coverage**: Generate test coverage report.
- **make format**: Automatically format code.
- **make lint**: Run lint checks.
- **make test**: Execute all tests.
- **make tidy**: Tidy dependencies.
Run these commands to ensure your code adheres to project standards and best practices.
## ☕ Supporters
Fiber is an open-source project that runs on donations to pay the bills, e.g., our domain name, GitBook, Netlify, and serverless hosting. If you want to support Fiber, you can ☕ [**buy a coffee here**](https://buymeacoff.ee/fenny).
| | User | Donation |
| ---------------------------------------------------------- | ------------------------------------------------ | -------- |
|  | [@destari](https://github.com/destari) | ☕ x 10 |
|  | [@dembygenesis](https://github.com/dembygenesis) | ☕ x 5 |
| | [@thomasvvugt](https://github.com/thomasvvugt) | ☕ x 5 |
|  | [@hendratommy](https://github.com/hendratommy) | ☕ x 5 |
|  | [@ekaputra07](https://github.com/ekaputra07) | ☕ x 5 |
|  | [@jorgefuertes](https://github.com/jorgefuertes) | ☕ x 5 |
|  | [@candidosales](https://github.com/candidosales) | ☕ x 5 |
|  | [@l0nax](https://github.com/l0nax) | ☕ x 3 |
|  | [@bihe](https://github.com/bihe) | ☕ x 3 |
|  | [@justdave](https://github.com/justdave) | ☕ x 3 |
|  | [@koddr](https://github.com/koddr) | ☕ x 1 |
|  | [@lapolinar](https://github.com/lapolinar) | ☕ x 1 |
|  | [@diegowifi](https://github.com/diegowifi) | ☕ x 1 |
|  | [@ssimk0](https://github.com/ssimk0) | ☕ x 1 |
|  | [@raymayemir](https://github.com/raymayemir) | ☕ x 1 |
|  | [@melkorm](https://github.com/melkorm) | ☕ x 1 |
|  | [@marvinjwendt](https://github.com/marvinjwendt) | ☕ x 1 |
|  | [@toishy](https://github.com/toishy) | ☕ x 1 |
## 💻 Code Contributors
## ⭐️ Stargazers
## 🧾 License
Copyright (c) 2019-present [Fenny](https://github.com/fenny) and [Contributors](https://github.com/gofiber/fiber/graphs/contributors). `Fiber` is free and open-source software licensed under the [MIT License](https://github.com/gofiber/fiber/blob/main/LICENSE). Official logo was created by [Vic Shóstak](https://github.com/koddr) and distributed under [Creative Commons](https://creativecommons.org/licenses/by-sa/4.0/) license (CC BY-SA 4.0 International).
================================================
FILE: .github/SECURITY.md
================================================
# Security Policy
1. [Supported Versions](#supported-versions)
2. [Reporting security problems to Fiber](#reporting-security-problems-to-fiber)
3. [Security Points of Contact](#security-points-of-contact)
4. [Incident Response Process](#incident-response-process)
## Supported Versions
The table below shows the supported versions for Fiber which include security updates.
| Version | Supported |
| --------- | ------------------ |
| >= 1.12.6 | :white_check_mark: |
| < 1.12.6 | :x: |
## Reporting security problems to Fiber
**DO NOT CREATE AN ISSUE** to report a security problem. Instead, please
send us an e-mail at `team@gofiber.io` or join our discord server via
[this invite link](https://gofiber.io/discord) and send a private message
to any of the maintainers.
## Security Points of Contact
For security-related matters, please contact any of the
[@maintainers](https://github.com/orgs/gofiber/teams/maintainers).
The maintainers are the only persons with administrative access to Fiber's source code
and respond to security incident reports as fast as possible, within one business day
at the latest.
## Incident Response Process
If an incident is discovered or reported, we will follow this
process to contain, respond, and remediate:
### 1. Containment
The first step is to find out the root cause, nature and scope of the incident.
- Is it still ongoing? If yes, first priority is to stop it.
- Is the incident outside of our influence? If yes, first priority is to contain it.
- Find out who knows about the incident and who is affected.
- Find out what data was potentially exposed.
### 2. Response
After the initial assessment and containment to our best abilities, we will
document all actions taken in a response plan.
We will post a message in the official `#announcements` channel to inform users about
the incident and what actions we took to contain it.
### 3. Remediation
Once the incident is confirmed to be resolved, we will summarize the lessons
learned from the incident and create a list of actions we will take to prevent
it from happening again.
### Secure accounts with access
The [Fiber Organization](https://github.com/gofiber) requires 2FA authorization
for all of its members.
### Critical Updates And Security Notices
We learn about critical software updates and security threats from these sources:
1. GitHub Security Alerts
2. GitHub: [https://www.githubstatus.com/](https://www.githubstatus.com/) & [@githubstatus](https://twitter.com/githubstatus)
================================================
FILE: .github/codecov.yml
================================================
coverage:
status:
project:
default:
target: auto
threshold: 0.5%
base: auto
patch:
default:
target: auto
threshold: 0.5%
base: auto
ignore:
# Ignore generated root files
- "*_msgp.go"
- "*_msgp_test.go"
- "*_gen.go"
# Ignore generated files below root
- "**/*_msgp.go"
- "**/*_msgp_test.go"
- "**/*_gen.go"
# Ignore internal and docs
- "internal/**"
- "docs/**"
================================================
FILE: .github/config.yml
================================================
# Configuration for new-issue-welcome - https://github.com/behaviorbot/new-issue-welcome
# Comment to be posted on first-time issues
newIssueWelcomeComment: >
Thanks for opening your first issue here! 🎉 Be sure to follow the issue template!
If you need help or want to chat with us, join us on Discord https://gofiber.io/discord
# Configuration for new-pr-welcome - https://github.com/behaviorbot/new-pr-welcome
# Comment to be posted on PRs from first-time contributors in your repository
newPRWelcomeComment: >
Thanks for opening this pull request! 🎉 Please check out our contributing guidelines.
If you need help or want to chat with us, join us on Discord https://gofiber.io/discord
# Configuration for first-pr-merge - https://github.com/behaviorbot/first-pr-merge
# Comment to be posted on pull requests merged by a first-time user
firstPRMergeComment: >
Congrats on merging your first pull request! 🎉 We here at Fiber are proud of you!
If you need help or want to chat with us, join us on Discord https://gofiber.io/discord
================================================
FILE: .github/copilot-instructions.md
================================================
# Copilot Usage
When modifying code, always perform these steps:
1. **Ensure code quality**
- `make format` to format the project.
- `make lint` for static analysis.
- `make test` to run the test suite.
2. **Maintain documentation**
Review and update the contents of the `docs` folder if necessary.
3. **Check Markdown**
- Finish by running `make markdown` to lint all Markdown files.
================================================
FILE: .github/copilot-setup-steps.yml
================================================
steps:
- run: |
if [ -d vendor ] || go list -m -mod=readonly all; then
echo "Dependencies already present"
else
go mod tidy && go mod download && go mod vendor
fi
- run: |
go install gotest.tools/gotestsum@latest
go install golang.org/x/vuln/cmd/govulncheck@latest
go install mvdan.cc/gofumpt@latest
go install github.com/tinylib/msgp@latest
go install github.com/vburenin/ifacemaker@975a95966976eeb2d4365a7fb236e274c54da64c
go install github.com/dkorunic/betteralign/cmd/betteralign@latest
- run: go mod tidy
================================================
FILE: .github/dependabot.yml
================================================
# https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "gomod"
directory: "/" # Location of package manifests
labels:
- "🤖 Dependencies"
schedule:
interval: "daily"
allow:
# Allow both direct and indirect updates for all packages.
- dependency-type: "all"
groups:
fasthttp-modules:
patterns:
- "github.com/valyala/fasthttp"
- "github.com/valyala/fasthttp/**"
golang-modules:
patterns:
- "golang.org/x/**"
valyala-utils-modules:
patterns:
- "github.com/valyala/bytebufferpool"
- "github.com/valyala/tcplisten"
google-modules:
patterns:
- "github.com/google/**"
- "google.golang.org/**"
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: daily
labels:
- "🤖 Dependencies"
================================================
FILE: .github/index.html
================================================
Test file
Hello, World!
================================================
FILE: .github/labeler.yml
================================================
_extends: .github
labels:
- label: 'v3'
matcher:
baseBranch: 'main'
title: '/(v3)/i'
body: '/(v3)/i'
- label: 'v2'
matcher:
baseBranch: 'v2'
title: '/(v2)/i'
================================================
FILE: .github/pull_request_template.md
================================================
# Description
Please provide a clear and concise description of the changes you've made and the problem they address. Include the purpose of the change, any relevant issues it solves, and the benefits it brings to the project. If this change introduces new features or adjustments, highlight them here.
Fixes # (issue)
## Changes introduced
List the new features or adjustments introduced in this pull request. Provide details on benchmarks, documentation updates, changelog entries, and if applicable, the migration guide.
- [ ] Benchmarks: Describe any performance benchmarks and improvements related to the changes.
- [ ] Documentation Update: Detail the updates made to the documentation and links to the changed files.
- [ ] Changelog/What's New: Include a summary of the additions for the upcoming release notes.
- [ ] Migration Guide: If necessary, provide a guide or steps for users to migrate their existing code to accommodate these changes.
- [ ] API Alignment with Express: Explain how the changes align with the Express API.
- [ ] API Longevity: Discuss the steps taken to ensure that the new or updated APIs are consistent and not prone to breaking changes.
- [ ] Examples: Provide examples demonstrating the new features or changes in action.
## Type of change
Please delete options that are not relevant.
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Enhancement (improvement to existing features and functionality)
- [ ] Documentation update (changes to documentation)
- [ ] Performance improvement (non-breaking change which improves efficiency)
- [ ] Code consistency (non-breaking change which improves code reliability and robustness)
## Checklist
Before you submit your pull request, please make sure you meet these requirements:
- [ ] Followed the inspiration of the Express.js framework for new functionalities, making them similar in usage.
- [ ] Conducted a self-review of the code and provided comments for complex or critical parts.
- [ ] Updated the documentation in the `/docs/` directory for [Fiber's documentation](https://docs.gofiber.io/).
- [ ] Added or updated unit tests to validate the effectiveness of the changes or new features.
- [ ] Ensured that new and existing unit tests pass locally with the changes.
- [ ] Verified that any new dependencies are essential and have been agreed upon by the maintainers/community.
- [ ] Aimed for optimal performance with minimal allocations in the new code.
- [ ] Provided benchmarks for the new code to analyze and improve upon.
## Commit formatting
Please use emojis in commit messages for an easy way to identify the purpose or intention of a commit. Check out the emoji cheatsheet here: [CONTRIBUTING.md](https://github.com/gofiber/fiber/blob/main/.github/CONTRIBUTING.md#pull-requests-or-commits)
================================================
FILE: .github/release-drafter.yml
================================================
name-template: 'v$RESOLVED_VERSION'
tag-template: 'v$RESOLVED_VERSION'
commitish: main
filter-by-commitish: true
include-labels:
- 'v3'
exclude-labels:
- 'v2'
categories:
- title: '❗ Breaking Changes'
labels:
- '❗ BreakingChange'
- title: '🚀 New'
labels:
- '✏️ Feature'
- '📝 Proposal'
- title: '🧹 Updates'
labels:
- '🧹 Updates'
- '⚡️ Performance'
- title: '🐛 Fixes'
labels:
- '☢️ Bug'
- title: '🛠️ Maintenance'
labels:
- '🤖 Dependencies'
- title: '📚 Documentation'
labels:
- '📒 Documentation'
change-template: '- $TITLE (#$NUMBER)'
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks.
exclude-contributors:
- dependabot
- dependabot[bot]
version-resolver:
major:
labels:
- '❗ BreakingChange'
minor:
labels:
- '✏️ Feature'
patch:
labels:
- '📒 Documentation'
- '☢️ Bug'
- '🤖 Dependencies'
- '🧹 Updates'
- '⚡️ Performance'
default: patch
template: |
$CHANGES
**📒 Documentation**: https://docs.gofiber.io/next/
**💬 Discord**: https://gofiber.io/discord
**Full Changelog**: https://github.com/$OWNER/$REPOSITORY/compare/$PREVIOUS_TAG...v$RESOLVED_VERSION
Thank you $CONTRIBUTORS for making this release possible.
================================================
FILE: .github/release.yml
================================================
# .github/release.yml
changelog:
categories:
- title: '❗ Breaking Changes'
labels:
- '❗ BreakingChange'
- title: '🚀 New Features'
labels:
- '✏️ Feature'
- '📝 Proposal'
- title: '🧹 Updates'
labels:
- '🧹 Updates'
- '⚡️ Performance'
- title: '🐛 Bug Fixes'
labels:
- '☢️ Bug'
- title: '🛠️ Maintenance'
labels:
- '🤖 Dependencies'
- title: '📚 Documentation'
labels:
- '📒 Documentation'
- title: 'Other Changes'
labels:
- '*'
================================================
FILE: .github/scripts/sync_docs.sh
================================================
#!/usr/bin/env bash
# Sync this repository's docs/ directory into the gofiber/docs repository.
#
# Environment variables read by this script (not defined here, so they must be
# supplied by the caller):
#   EVENT    - "push" or "release"
#   TOKEN    - token used in the clone/push URLs of the docs repository
#   TAG_NAME - release tag name (release events only)
set -e

# Some env variables
BRANCH="main"
REPO_URL="github.com/gofiber/docs.git"
AUTHOR_EMAIL="github-actions[bot]@users.noreply.github.com"
AUTHOR_USERNAME="github-actions[bot]"
VERSION_FILE="versions.json"
REPO_DIR="core"
COMMIT_URL="https://github.com/gofiber/fiber"
DOCUSAURUS_COMMAND="npm run docusaurus -- docs:version"
DOCS_DIR="fiber-docs"

# Set commit author
git config --global user.email "${AUTHOR_EMAIL}"
git config --global user.name "${AUTHOR_USERNAME}"

git clone "https://${TOKEN}@${REPO_URL}" "${DOCS_DIR}"

# Handle push event
if [ "$EVENT" == "push" ]; then
  latest_commit=$(git rev-parse --short HEAD)
  #log_output=$(git log --oneline ${BRANCH} HEAD~1..HEAD --name-status -- docs/)
  #if [[ $log_output != "" ]]; then
  cp -a docs/* "${DOCS_DIR}/docs/${REPO_DIR}"
  #fi

# Handle release event
elif [ "$EVENT" == "release" ]; then
  major_version="${TAG_NAME%%.*}"
  echo "Major version: $major_version"

  # Form new version name
  new_version="${major_version}.x"
  echo "New version: $new_version"

  # Work inside the docs checkout; popd restores the original directory so the
  # shared "Push changes" step below can enter the checkout unconditionally.
  pushd "${DOCS_DIR}" >/dev/null
  npm ci

  # Drop any existing entry for this version from the version file before
  # running the docusaurus versioning command below.
  if [[ -f $VERSION_FILE ]]; then
    echo "Modifying version file: $VERSION_FILE"
    jq --arg new_version "$new_version" 'del(.[] | select(. == $new_version))' "$VERSION_FILE" > temp.json && mv temp.json "$VERSION_FILE"
  fi

  # Run docusaurus versioning command (intentionally unquoted so it word-splits)
  $DOCUSAURUS_COMMAND "${new_version}"

  if [[ -f $VERSION_FILE ]]; then
    echo "Sorting version file: $VERSION_FILE"
    jq 'sort | reverse' "${VERSION_FILE}" > temp.json && mv temp.json "${VERSION_FILE}"
  fi
  popd >/dev/null
fi

# Push changes
cd "${DOCS_DIR}"
git add .
if [[ $EVENT == "push" ]]; then
  git commit -m "Add docs from ${COMMIT_URL}/commit/${latest_commit}"
elif [[ $EVENT == "release" ]]; then
  git commit -m "Sync docs for release ${COMMIT_URL}/releases/tag/${TAG_NAME}"
fi

MAX_RETRIES=5
DELAY=5
retry=0

while ((retry < MAX_RETRIES)); do
  git push "https://${TOKEN}@${REPO_URL}" && break
  retry=$((retry + 1))
  # Under `set -e` a failing pull would abort the script and skip the remaining
  # retries, so log it and let the next push attempt (or the final check) fail.
  git pull --rebase || echo "git pull --rebase failed (attempt ${retry}/${MAX_RETRIES})"
  sleep "$DELAY"
done

if ((retry == MAX_RETRIES)); then
  echo "Failed to push after $MAX_RETRIES attempts. Exiting with 1."
  exit 1
fi
================================================
FILE: .github/testdata/ca-chain.cert.pem
================================================
-----BEGIN CERTIFICATE-----
MIIFeTCCA2GgAwIBAgIDEAISMA0GCSqGSIb3DQEBCwUAMFYxCzAJBgNVBAYTAlVT
MQ8wDQYDVQQIDAZEZW5pYWwxFDASBgNVBAcMC1NwcmluZ2ZpZWxkMQwwCgYDVQQK
DANEaXMxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMjAyMDgyMDA3MjlaFw0zMjAy
MDYyMDA3MjlaMEAxCzAJBgNVBAYTAlVTMQ8wDQYDVQQIDAZEZW5pYWwxDDAKBgNV
BAoMA0RpczESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEFAAOC
Ag8AMIICCgKCAgEA5Cho0kbBDi1cy8bURStc95hK2RzjBQMd2hN5gFxZdF5knBfC
LSiPMxtAn9zJYzYc9+Cq7hIOK19cgG4yKk9GFZaUe+mU4yWxRg1ViSu/jzQ04sVc
JRSbSklXY1RNyxpUtGelxnluUvdvuXXlCPmKob4IsUtI1FTcumG1mzIO+cAzBd1J
KQtNTUO9XfSHYusV/FQO2hIbaXcFgSAg50JJfYZaUw51J07j3vdb6lb1x4rRmIaq
8txrdHo0Y2tXHsq6jry1QrOZfoz4WbYcoID3JU1MC5f1HyR5uYiCA1RJVGnQ3iSX
3yM+gRy3SFPeaASs2useSzGkMr/pDlbcSVmsbXsasBxZq85T1FE8vuY6K4XlU2sN
PyiPrNjDgVkQ8Lbj1B9oKEYKkmSieBx9YwRLarfru1kt+g3kdXuel7DyHpm+j/13
vqjyF9DAyx4wAEZC+DzeqBsbuiDdRkzwFMcKPxYpgSTLawnCjlFapPvE5kGN+O/j
To2qWbWUU/upzBvHu4tnICSapJJ0VqA+7M7yaBAsIWK/yjNTzpHfx0oHudl8wBOG
ySfOE52uouFsp2vs06YpEg2nGn/7Iu0Rbbwt4iFcSZlEnSk0cQlyMZxdj3M2fMKa
/nrRQm7guPbVmBJOFHZuTTiilNSduSsDwCjJkGdJkSVYbj3+eJzKwYstnT8CAwEA
AaNmMGQwHQYDVR0OBBYEFHwli1hTCVJLHPTHWW8O8BCaHci9MB8GA1UdIwQYMBaA
FDlb/7rpDA2ZzsLZmqbW/krUtmGOMBIGA1UdEwEB/wQIMAYBAf8CAQAwDgYDVR0P
AQH/BAQDAgGGMA0GCSqGSIb3DQEBCwUAA4ICAQBtMxa2/w6kGF9cmqpTdQ1La8nY
R4Zoewnn+SCmcSOwCyBC32g0Ry6nKKUpJBpJEid5lBzWveIw4K7pdWvmuqmeMuWI
ilvlCLzqYPigmnEIW96hc6XiQvl9NC5j+SAZSC+4uCNhEUx5pEbE1FU1gIX+szdJ
tLdPwwg63Ce/us6QZ7Tx8qLIr+XU+DrCgjIheQFShtoNYDw0GxEtjeo8vHynj8EZ
+p0OZgqoNlnRbQbllruDFPXDJVI23DVhNpJhT86iQDMtMV53ypMu62LXmdQIKa7l
ITnEMGO626RKqw2kDHt7yinBlt1nHskaeeLya6K/08uJkqRCjzOshJgsjQ3e62vQ
Mht9QvGBCAoY09fIGxRihtTWCFDe7MEnbh1PPYB7cZTOMnL3wxRPzLLYhclX+pt0
bBf7Dn84b3tdC5BFXBJeZMs5QSCvn4yrTew+NvvX3oL6Ny1JDZYaG5PhKf00J6iy
TkXzK2n9U9RX+krPk8fU9Ae1nayD0vrmGaVcBdRQPn4XUuS3LhdlkHfr28z1nF9m
ffd0WBrJlNX9SoKtsMj8VJFZ/nJ0EcCcY1mG3k/IGAY1HUeo4A+C5E/UO3h4+tOL
uqUa8rkl9HoE4fIWdQVxtQjEdATSuJusaK7CFpWH8A0w9VchDx74saiwwGhVyXYk
yBwSA5U88ymkQ7qNJg==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIFnTCCA4WgAwIBAgIUQnfvDIm6z+973AYRTLorZHEQA30wDQYJKoZIhvcNAQEL
BQAwVjELMAkGA1UEBhMCVVMxDzANBgNVBAgMBkRlbmlhbDEUMBIGA1UEBwwLU3By
aW5nZmllbGQxDDAKBgNVBAoMA0RpczESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIy
MDIwODIwMDcyOVoXDTQyMDIwMzIwMDcyOVowVjELMAkGA1UEBhMCVVMxDzANBgNV
BAgMBkRlbmlhbDEUMBIGA1UEBwwLU3ByaW5nZmllbGQxDDAKBgNVBAoMA0RpczES
MBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKC
AgEA+Cf/fKPvI5+Nh81wpxghLrpAjM1MyhdHUDXc8bu7NTNYZ+6ArMqDeKSszTWT
gLV9EeJs57KwjwXIoYZDTcLvpjanrZ2s7JDEqsGl7S6Xr67qzYghlF/GaB3f3lAi
GsvhmDgC4jkdCvkVBKOB4tE0dy6fnNCmIKhhJDje51p90LWFuX5sIKO2trgte6b/
P1PW8rOjedPd5Z+QCG4Mi8JnbJicX1YaOySGMcXHm/eM2wy5I3pEdUreZDFbjB7l
CKa1kFAnDXBiQAoErggMlcXe6C+avB13wYCfi+R+9m2X0svmerSz+oHDCOhvnD52
EE7fBv89VS6pR8mykz1eHiBKkVT1qdmONThUQ8mqwxlo06bZyoykKSxG6ffJnGbM
GkTWlaNEdZuY0pITQLxEX2SwLJuZKfTPFheno83bLCqTOTuWo7h1mLe0ogJMxl//
NHzyoMJ0b1bbrRgsLcMaDn1MsI+gVdRY89+cZ2uedEgrr7fl3KFWiF2S57bxX1fJ
P8HC0bzMny4jMtIf6YUqDtpGDPjZq3PGqcrcO0dVYkuaC96H3xLcvQxgkMDK0sm0
pUbWlECzAag/lxeeC22VnedqgCpiq/9z1b4j884ZkhIJyht0HR7L4I0gO/R/mWUY
8bO9XCMYmP0CjO7u93IlzQ7aSIpWprTHxjpiPelmcO001jkCAwEAAaNjMGEwHQYD
VR0OBBYEFDlb/7rpDA2ZzsLZmqbW/krUtmGOMB8GA1UdIwQYMBaAFDlb/7rpDA2Z
zsLZmqbW/krUtmGOMA8GA1UdEwEB/wQFMAMBAf8wDgYDVR0PAQH/BAQDAgGGMA0G
CSqGSIb3DQEBCwUAA4ICAQDk02Tu0ypqnS897WBx98l2nYIrEhcrJg8ZMmSwEa9J
7TANofzsP9931YoQMh6BI6hB3OkyL6FYTUDykpGMMasojtL/F2iXEsjema2ilZ/7
hNAZ+j5mBemMwXfkfmRguXvnl7EWaZETgEoxhcOTYoYUYqDcyzuwK63fOs+YA5ke
O8E3F1aLHzLpqVpiG7t740L7LdibNPko9JOd31Gqcq3nhXMf6/rOdL8VSj/F+4BG
opgJBruJV9NxWRI/b0G6eImvaYL/Ljfd8wzwNpmYkNkHbhAiaHeXJQ05mebmr2Dr
wne9QeSJkXCs/K5A/8+0CYNN4homt8xNNN02SnJ5e6nv1A1ntMW9n6n2KYo87tz9
VmqWXg7Y1BqXj287WRaWPJsBa2RBP1W2d3BQfHKJfu15blyXaczTi87WayEsBnQq
TXy+1QP0IwQerSTOxdW25UoJmH18SRbLEIQs9Htvcpz2AncTYjeiLFa15FO2r5hP
LYc9QOKn6yIZP9lYztleEqOLTmHnRnFcupDol+/x88d+kVLqmXDiKmWbVIz7C735
xgImsyrCPPYYiEA7/yaP5G1o5XU93kRPrtg/7jjyF+uBZ70fcbED3prpuiJYrL0O
gvQUgmGUU30mPHjAKkEACeRXtoqucRDxvIkBb5zUvZG8RmSFae5siAWwLD7D7VJa
IA==
-----END CERTIFICATE-----
================================================
FILE: .github/testdata/fs/css/style.css
================================================
h1 {
color: red;
text-align: center;
}
================================================
FILE: .github/testdata/fs/css/test/style2.css
================================================
h1 {
color: black;
}
================================================
FILE: .github/testdata/fs/index.html
================================================
Document
", string(c.Response().Body()))
}
// Benchmark_Ctx_Get_Location_From_Route measures looking up a named route and
// expanding it into a concrete path with getLocationFromRoute.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_Get_Location_From_Route -benchmem -count=4
func Benchmark_Ctx_Get_Location_From_Route(b *testing.B) {
	app := New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	app.Get("/user/:name", func(c Ctx) error {
		return c.SendString(c.Params("name"))
	}).Name("User")

	var (
		location string
		err      error
	)
	for b.Loop() {
		// Both the lookup and the expansion are part of the measured work.
		route := app.GetRoute("User")
		location, err = ctx.getLocationFromRoute(&route, Map{"name": "fiber"})
	}

	require.NoError(b, err)
	require.Equal(b, "/user/fiber", location)
}
// Test_Ctx_Get_Location_From_Route_name checks how GetRouteURL matches the
// supplied parameter keys against route parameters. With the default config a
// key differing only in case ("Name") still fills ":name". With CaseSensitive
// enabled it does not, leaving the parameter segment empty.
//
// go test -run Test_Ctx_Get_Location_From_Route_name
func Test_Ctx_Get_Location_From_Route_name(t *testing.T) {
	t.Parallel()
	t.Run("case-insensitive", func(t *testing.T) {
		t.Parallel()
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		// Release the context like the case-sensitive subtest does.
		defer app.ReleaseCtx(c)
		app.Get("/user/:name", func(c Ctx) error {
			return c.SendString(c.Params("name"))
		}).Name("User")

		location, err := c.GetRouteURL("User", Map{"name": "fiber"})
		require.NoError(t, err)
		require.Equal(t, "/user/fiber", location)

		location, err = c.GetRouteURL("User", Map{"Name": "fiber"})
		require.NoError(t, err)
		require.Equal(t, "/user/fiber", location)
	})

	t.Run("case-sensitive", func(t *testing.T) {
		t.Parallel()
		app := New(Config{CaseSensitive: true})
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(c)
		app.Get("/user/:name", func(c Ctx) error {
			return c.SendString(c.Params("name"))
		}).Name("User")

		location, err := c.GetRouteURL("User", Map{"name": "fiber"})
		require.NoError(t, err)
		require.Equal(t, "/user/fiber", location)

		// "Name" does not match ":name" here, so the segment stays empty.
		location, err = c.GetRouteURL("User", Map{"Name": "fiber"})
		require.NoError(t, err)
		require.Equal(t, "/user/", location)
	})
}
// Test_Ctx_Get_Location_From_Route_name_Optional_greedy checks that a route
// with two unnamed wildcards is expanded by filling them from the "*1" and "*2"
// keys in order.
//
// go test -run Test_Ctx_Get_Location_From_Route_name_Optional_greedy
func Test_Ctx_Get_Location_From_Route_name_Optional_greedy(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	app.Get("/:phone/*/send/*", func(c Ctx) error {
		return c.SendString("Phone: " + c.Params("phone") + "\nFirst Param: " + c.Params("*1") + "\nSecond Param: " + c.Params("*2"))
	}).Name("SendSms")

	params := Map{
		"phone": "23456789",
		"*1":    "sms",
		"*2":    "test-msg",
	}
	location, err := c.GetRouteURL("SendSms", params)
	require.NoError(t, err)
	require.Equal(t, "/23456789/sms/send/test-msg", location)
}
// Test_Ctx_Get_Location_From_Route_name_Optional_greedy_one_param checks that
// a route with a single unnamed wildcard can be expanded using the plain "*"
// key.
//
// go test -run Test_Ctx_Get_Location_From_Route_name_Optional_greedy_one_param
func Test_Ctx_Get_Location_From_Route_name_Optional_greedy_one_param(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	app.Get("/:phone/*/send", func(c Ctx) error {
		return c.SendString("Phone: " + c.Params("phone") + "\nFirst Param: " + c.Params("*1"))
	}).Name("SendSms")

	params := Map{
		"phone": "23456789",
		"*":     "sms",
	}
	location, err := c.GetRouteURL("SendSms", params)
	require.NoError(t, err)
	require.Equal(t, "/23456789/sms/send", location)
}
// errorTemplateEngine is a stub view engine whose Render always fails. Tests
// assign it to app.config.Views to exercise Render's error path.
type errorTemplateEngine struct{}

// Render ignores all of its arguments and always returns an error.
func (errorTemplateEngine) Render(io.Writer, string, any, ...string) error {
	renderErr := errors.New("errorTemplateEngine")
	return renderErr
}

// Load does nothing and always succeeds.
func (errorTemplateEngine) Load() error {
	return nil
}
// Test_Ctx_Render_Engine_Error checks that Render returns an error when the
// configured view engine fails.
//
// go test -run Test_Ctx_Render_Engine_Error
func Test_Ctx_Render_Engine_Error(t *testing.T) {
	t.Parallel()
	app := New()
	app.config.Views = errorTemplateEngine{}
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	require.Error(t, c.Render("index.tmpl", nil))
}
// Test_Ctx_Render_Go_Template renders a temporary file by path, with no Views
// engine set on the app, and expects the file contents as the response body.
//
// go test -run Test_Ctx_Render_Go_Template
func Test_Ctx_Render_Go_Template(t *testing.T) {
	t.Parallel()

	tmpl, err := os.CreateTemp(os.TempDir(), "fiber")
	require.NoError(t, err)
	defer func() {
		require.NoError(t, os.Remove(tmpl.Name()))
	}()

	_, err = tmpl.WriteString("template")
	require.NoError(t, err)
	require.NoError(t, tmpl.Close())

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	require.NoError(t, c.Render(tmpl.Name(), nil))
	require.Equal(t, "template", string(c.Response().Body()))
}
// Test_Ctx_Send checks that repeated Send calls succeed and that the body ends
// up holding only the payload from the last call.
//
// go test -run Test_Ctx_Send
func Test_Ctx_Send(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	for _, payload := range []string{"Hello, World", "Don't crash please", "1337"} {
		require.NoError(t, c.Send([]byte(payload)))
	}
	require.Equal(t, "1337", string(c.Response().Body()))
}
// Benchmark_Ctx_Send measures setting the response body from a byte slice.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_Send -benchmem -count=4
func Benchmark_Ctx_Send(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	payload := []byte("Hello, World!")

	b.ReportAllocs()
	var sendErr error
	for b.Loop() {
		sendErr = c.Send(payload)
	}

	require.NoError(b, sendErr)
	require.Equal(b, "Hello, World!", string(c.Response().Body()))
}
// Test_Ctx_SendStatus checks that SendStatus sets the status code and, for a
// status that allows a body, writes the status text as the body.
//
// go test -run Test_Ctx_SendStatus
func Test_Ctx_SendStatus(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	const code = 415
	require.NoError(t, c.SendStatus(code))
	require.Equal(t, code, c.Response().StatusCode())
	require.Equal(t, "Unsupported Media Type", string(c.Response().Body()))
}
// Test_Ctx_SendStatusNoBodyResponses checks that SendStatus leaves an empty
// body and a zero Content-Length for status codes exercised here (1xx, 204,
// 205, 304), even when a body was set beforehand.
func Test_Ctx_SendStatusNoBodyResponses(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name   string
		status int
	}{
		{name: "Informational", status: StatusContinue},
		{name: "Processing", status: StatusProcessing},
		{name: "SwitchingProtocols", status: StatusSwitchingProtocols},
		{name: "EarlyHints", status: StatusEarlyHints},
		{name: "NoContent", status: StatusNoContent},
		{name: "ResetContent", status: StatusResetContent},
		{name: "NotModified", status: StatusNotModified},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := New()
			c := app.AcquireCtx(&fasthttp.RequestCtx{})

			// Start with a non-empty body so the assertions prove it was dropped.
			c.Response().SetBodyString("preset body")

			require.NoError(t, c.SendStatus(tc.status))
			require.Empty(t, c.Response().Body())
			require.Equal(t, 0, c.Response().Header.ContentLength())
		})
	}
}
// Test_Ctx_SendString checks that SendString writes the given text as the body.
//
// go test -run Test_Ctx_SendString
func Test_Ctx_SendString(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	const body = "Don't crash please"
	require.NoError(t, c.SendString(body))
	require.Equal(t, body, string(c.Response().Body()))
}
// Test_Ctx_SendStream checks SendStream in three cases: with a plain reader,
// with an explicit size argument, and with a bufio-wrapped reader.
//
// go test -run Test_Ctx_SendStream
func Test_Ctx_SendStream(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	const msg = "Don't crash please"

	require.NoError(t, c.SendStream(bytes.NewReader([]byte(msg))))
	require.Equal(t, msg, string(c.Response().Body()))

	// Passing the stream size explicitly yields the same body.
	require.NoError(t, c.SendStream(bytes.NewReader([]byte(msg)), len(msg)))
	require.Equal(t, msg, string(c.Response().Body()))

	buffered := bufio.NewReader(bytes.NewReader([]byte("Hello bufio")))
	require.NoError(t, c.SendStream(buffered))
	require.Equal(t, "Hello bufio", string(c.Response().Body()))
}
// Test_Ctx_SendStreamWriter checks three cases: a single unflushed write,
// several writes each flushed immediately, and a writer that writes nothing.
// In each case it compares the resulting response body.
//
// go test -run Test_Ctx_SendStreamWriter
func Test_Ctx_SendStreamWriter(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	// Single write, never flushed explicitly.
	err := c.SendStreamWriter(func(w *bufio.Writer) {
		w.WriteString("Don't crash please") //nolint:errcheck // It is fine to ignore the error
	})
	require.NoError(t, err)
	require.Equal(t, "Don't crash please", string(c.Response().Body()))

	// Five lines, each followed by a flush.
	err = c.SendStreamWriter(func(w *bufio.Writer) {
		for i := range 5 {
			fmt.Fprintf(w, "Line %d\n", i+1)
			if flushErr := w.Flush(); flushErr != nil {
				t.Errorf("unexpected error: %s", flushErr)
				return
			}
		}
	})
	require.NoError(t, err)
	require.Equal(t, "Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n", string(c.Response().Body()))

	// A writer that produces nothing leaves the body empty.
	err = c.SendStreamWriter(func(_ *bufio.Writer) {})
	require.NoError(t, err)
	require.Empty(t, c.Response().Body())
}
// Test_Ctx_SendStreamWriter_Interrupted checks that app.Test fails with
// os.ErrDeadlineExceeded, and returns no response, when a streaming handler
// stalls past the test timeout and FailOnTimeout is set.
//
// go test -run Test_Ctx_SendStreamWriter_Interrupted
func Test_Ctx_SendStreamWriter_Interrupted(t *testing.T) {
	t.Parallel()
	app := New()

	// NOTE(review): the stream writer records these counters, but nothing below
	// asserts them. The writer is still sleeping when app.Test times out, so
	// reading them here would likely be timing-dependent.
	var (
		flushed      atomic.Int32
		flushErrLine atomic.Int32
	)

	app.Get("/", func(c Ctx) error {
		return c.SendStreamWriter(func(w *bufio.Writer) {
			for lineNum := 1; lineNum <= 5; lineNum++ {
				fmt.Fprintf(w, "Line %d\n", lineNum)
				if err := w.Flush(); err != nil {
					flushErrLine.Store(int32(lineNum)) //nolint:gosec // G115 - lineNum is 1-5, fits int32
					return
				}
				if lineNum <= 3 {
					flushed.Add(1)
				}
				// Stall after the third line, well past the 200ms test timeout.
				if lineNum == 3 {
					time.Sleep(500 * time.Millisecond)
				}
			}
		})
	})

	req := httptest.NewRequest(MethodGet, "/", http.NoBody)

	// The timeout expires while the writer sleeps after line 3. FailOnTimeout
	// turns that into an error instead of a partial response.
	resp, err := app.Test(req, TestConfig{
		Timeout:       200 * time.Millisecond,
		FailOnTimeout: true,
	})
	require.ErrorIs(t, err, os.ErrDeadlineExceeded)
	require.Nil(t, resp)
}
// Test_Ctx_Set checks that Set writes response headers, that a repeated Set on
// the same key keeps the last value, and that lookups ignore header-name case.
//
// go test -run Test_Ctx_Set
func Test_Ctx_Set(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	headers := [][2]string{
		{"X-1", "1"},
		{"X-2", "2"},
		{"X-3", "3"},
		{"X-3", "1337"}, // overwrites the previous X-3 value
	}
	for _, kv := range headers {
		c.Set(kv[0], kv[1])
	}

	require.Equal(t, "1", string(c.Response().Header.Peek("x-1")))
	require.Equal(t, "2", string(c.Response().Header.Peek("x-2")))
	require.Equal(t, "1337", string(c.Response().Header.Peek("x-3")))
}
// Test_Ctx_Set_Splitter checks that header values containing CR/LF sequences
// are sanitized, so a value cannot be used to split in extra headers.
//
// go test -run Test_Ctx_Set_Splitter
func Test_Ctx_Set_Splitter(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	attempts := []struct{ value, forbidden string }{
		{"foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n", "\r\n"},
		{"foo\nSet-Cookie:%20SESSIONID=MaliciousValue\n", "\n"},
	}
	for _, attempt := range attempts {
		c.Set("Location", attempt.value)
		require.NotContains(t, string(c.Response().Header.Peek("Location")), attempt.forbidden)
	}
}
// Benchmark_Ctx_Set measures repeatedly setting the same response header.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_Set -benchmem -count=4
func Benchmark_Ctx_Set(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	const requestID = "1431-15132-3423"

	b.ReportAllocs()
	for b.Loop() {
		c.Set(HeaderXRequestID, requestID)
	}
}
// Test_Ctx_Status checks that Status sets the response code and that it can
// be chained with Send.
//
// go test -run Test_Ctx_Status
func Test_Ctx_Status(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Status(400)
	require.Equal(t, 400, c.Response().StatusCode())

	require.NoError(t, c.Status(415).Send([]byte("Hello, World")))
	require.Equal(t, 415, c.Response().StatusCode())
	require.Equal(t, "Hello, World", string(c.Response().Body()))
}
// Test_Ctx_Type checks the Content-Type produced by Type for several
// extensions, with and without a leading dot or an explicit charset.
//
// go test -run Test_Ctx_Type
func Test_Ctx_Type(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	// Steps run in order on the same context; each Type call replaces the header.
	steps := []struct {
		ext     string
		charset []string
		want    string
	}{
		{ext: ".json", want: "application/json; charset=utf-8"},
		{ext: "json", charset: []string{"utf-8"}, want: "application/json; charset=utf-8"},
		{ext: ".html", want: "text/html; charset=utf-8"},
		{ext: "html", charset: []string{"utf-8"}, want: "text/html; charset=utf-8"},
		// Other text types get UTF-8 by default.
		{ext: "txt", want: "text/plain; charset=utf-8"},
		{ext: "css", want: "text/css; charset=utf-8"},
		{ext: "js", want: "text/javascript; charset=utf-8"},
		{ext: "xml", want: "application/xml; charset=utf-8"},
		// Binary types don't get a charset.
		{ext: "png", want: "image/png"},
		{ext: "pdf", want: "application/pdf"},
		// An explicit charset overrides the default.
		{ext: "html", charset: []string{"iso-8859-1"}, want: "text/html; charset=iso-8859-1"},
	}

	for _, step := range steps {
		c.Type(step.ext, step.charset...)
		require.Equal(t, step.want, string(c.Response().Header.Peek("Content-Type")))
	}
}
// Test_shouldIncludeCharset checks which MIME types are expected to carry a
// charset parameter: text/*, JSON/JavaScript/XML, and +json/+xml suffixes do;
// binary and other application types do not.
//
// go test -run Test_shouldIncludeCharset
func Test_shouldIncludeCharset(t *testing.T) {
	t.Parallel()

	withCharset := []string{
		// text/* types
		"text/html", "text/plain", "text/css", "text/javascript", "text/xml",
		// explicitly listed application types
		"application/json", "application/javascript", "application/xml",
		// +json suffixes
		"application/problem+json", "application/vnd.api+json",
		"application/hal+json", "application/merge-patch+json",
		// +xml suffixes
		"application/soap+xml", "application/xhtml+xml",
		"application/atom+xml", "application/rss+xml",
	}
	for _, mime := range withCharset {
		require.True(t, shouldIncludeCharset(mime), mime)
	}

	withoutCharset := []string{
		// binary types
		"image/png", "image/jpeg", "application/pdf",
		"application/octet-stream", "video/mp4", "audio/mpeg",
		// other application types
		"application/cbor", "application/x-www-form-urlencoded", "application/vnd.msgpack",
	}
	for _, mime := range withoutCharset {
		require.False(t, shouldIncludeCharset(mime), mime)
	}
}
// Benchmark_Ctx_Type measures Type with and without a leading dot and with no
// explicit charset.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_Type -benchmem -count=4
func Benchmark_Ctx_Type(b *testing.B) {
	app := New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})

	b.ReportAllocs()
	for b.Loop() {
		ctx.Type(".json")
		ctx.Type("json")
	}
}
// Benchmark_Ctx_Type_Charset measures Type with an explicit "utf-8" charset,
// calling it directly on *DefaultCtx.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_Type_Charset -benchmem -count=4
func Benchmark_Ctx_Type_Charset(b *testing.B) {
	app := New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	b.ReportAllocs()
	for b.Loop() {
		ctx.Type(".json", "utf-8")
		ctx.Type("json", "utf-8")
	}
}
// Test_Ctx_Vary checks that successive Vary calls accumulate field names into
// a single comma-separated Vary header in call order.
//
// go test -run Test_Ctx_Vary
func Test_Ctx_Vary(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	for _, fields := range [][]string{{"Origin"}, {"User-Agent"}, {"Accept-Encoding", "Accept"}} {
		c.Vary(fields...)
	}
	require.Equal(t, "Origin, User-Agent, Accept-Encoding, Accept", string(c.Response().Header.Peek("Vary")))
}
// Benchmark_Ctx_Vary measures adding two fields to the Vary header per call.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_Vary -benchmem -count=4
func Benchmark_Ctx_Vary(b *testing.B) {
	app := New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	b.ReportAllocs()
	for b.Loop() {
		ctx.Vary("Origin", "User-Agent")
	}
}
// Test_Ctx_Write checks that successive Write calls append to the response
// body and report the number of bytes written.
//
// go test -run Test_Ctx_Write
func Test_Ctx_Write(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	// Exercise Write itself. This test previously called WriteString, which
	// Test_Ctx_WriteString already covers.
	n, err := c.Write([]byte("Hello, "))
	require.NoError(t, err)
	require.Equal(t, len("Hello, "), n)

	n, err = c.Write([]byte("World!"))
	require.NoError(t, err)
	require.Equal(t, len("World!"), n)

	require.Equal(t, "Hello, World!", string(c.Response().Body()))
}
// Benchmark_Ctx_Write measures appending a byte slice to the response body.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_Write -benchmem -count=4
func Benchmark_Ctx_Write(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	chunk := []byte("Hello, World!")

	b.ReportAllocs()
	var writeErr error
	for b.Loop() {
		_, writeErr = c.Write(chunk)
	}
	require.NoError(b, writeErr)
}
// Test_Ctx_Writef checks that Writef formats its arguments into the body.
//
// go test -run Test_Ctx_Writef
func Test_Ctx_Writef(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	const world = "World!"
	_, err := c.Writef("Hello, %s", world)
	require.NoError(t, err)
	require.Equal(t, "Hello, World!", string(c.Response().Body()))
}
// Benchmark_Ctx_Writef measures formatted writes to the response body.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_Writef -benchmem -count=4
func Benchmark_Ctx_Writef(b *testing.B) {
	app := New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
	world := "World!"

	b.ReportAllocs()
	var writeErr error
	for b.Loop() {
		_, writeErr = ctx.Writef("Hello, %s", world)
	}
	require.NoError(b, writeErr)
}
// Test_Ctx_WriteString checks that successive WriteString calls append to the
// response body.
//
// go test -run Test_Ctx_WriteString
func Test_Ctx_WriteString(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	for _, part := range []string{"Hello, ", "World!"} {
		_, err := c.WriteString(part)
		require.NoError(t, err)
	}
	require.Equal(t, "Hello, World!", string(c.Response().Body()))
}
// Test_Ctx_XHR checks that XHR reports true when X-Requested-With is
// "XMLHttpRequest".
//
// go test -run Test_Ctx_XHR
func Test_Ctx_XHR(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Request().Header.Set(HeaderXRequestedWith, "XMLHttpRequest")
	isXHR := c.XHR()
	require.True(t, isXHR)
}
// Benchmark_Ctx_XHR measures the X-Requested-With check performed by XHR.
//
// go test -run=^$ -bench=Benchmark_Ctx_XHR -benchmem -count=4
func Benchmark_Ctx_XHR(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().Header.Set(HeaderXRequestedWith, "XMLHttpRequest")

	var isXHR bool
	b.ReportAllocs()
	for b.Loop() {
		isXHR = c.XHR()
	}
	require.True(b, isXHR)
}
// Benchmark_Ctx_SendString_B measures setting the response body from a string.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_SendString_B -benchmem -count=4
func Benchmark_Ctx_SendString_B(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	payload := "Hello, world!"

	b.ReportAllocs()
	var sendErr error
	for b.Loop() {
		sendErr = c.SendString(payload)
	}

	require.NoError(b, sendErr)
	require.Equal(b, []byte("Hello, world!"), c.Response().Body())
}
// Test_Ctx_Queries checks the flat key/value map returned by Queries and
// Req().Queries(). Repeated keys keep their last value, comma-separated values
// are not split, and bracket/dot-style keys are returned verbatim.
//
// go test -run Test_Ctx_Queries -v
func Test_Ctx_Queries(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	// expect asserts that every key in want maps to the given value in got.
	expect := func(got, want map[string]string) {
		t.Helper()
		for key, value := range want {
			require.Equal(t, value, got[key], "query key %q", key)
		}
	}

	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football&favouriteDrinks=milo,coke,pepsi&alloc=&no=1&field1=value1&field1=value2&field2=value3&list_a=1&list_a=2&list_a=3&list_b[]=1&list_b[]=2&list_b[]=3&list_c=1,2,3")
	expect(c.Queries(), map[string]string{
		"id":              "1",
		"name":            "tom",
		"hobby":           "basketball,football",
		"favouriteDrinks": "milo,coke,pepsi",
		"alloc":           "",
		"no":              "1",
		"field1":          "value2",
		"field2":          "value3",
		"list_a":          "3",
		"list_b[]":        "3",
		"list_c":          "1,2,3",
	})

	c.Request().URI().SetQueryString("filters.author.name=John&filters.category.name=Technology&filters[customer][name]=Alice&filters[status]=pending")
	expect(c.Queries(), map[string]string{
		"filters.author.name":     "John",
		"filters.category.name":   "Technology",
		"filters[customer][name]": "Alice",
		"filters[status]":         "pending",
	})

	// The same data is also reachable through the request API.
	c.Request().URI().SetQueryString("tags=apple,orange,banana&filters[tags]=apple,orange,banana&filters[category][name]=fruits&filters.tags=apple,orange,banana&filters.category.name=fruits")
	expect(c.Req().Queries(), map[string]string{
		"tags":                    "apple,orange,banana",
		"filters[tags]":           "apple,orange,banana",
		"filters[category][name]": "fruits",
		"filters.tags":            "apple,orange,banana",
		"filters.category.name":   "fruits",
	})

	c.Request().URI().SetQueryString("filters[tags][0]=apple&filters[tags][1]=orange&filters[tags][2]=banana&filters[category][name]=fruits")
	expect(c.Queries(), map[string]string{
		"filters[tags][0]":        "apple",
		"filters[tags][1]":        "orange",
		"filters[tags][2]":        "banana",
		"filters[category][name]": "fruits",
	})
}
// Benchmark_Ctx_Queries measures building the query map for a small query string.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_Queries -benchmem -count=4
func Benchmark_Ctx_Queries(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	b.ReportAllocs()
	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football&favouriteDrinks=milo,coke,pepsi&alloc=&no=1")

	var got map[string]string
	for b.Loop() {
		got = c.Queries()
	}

	require.Equal(b, "1", got["id"])
	require.Equal(b, "tom", got["name"])
	require.Equal(b, "basketball,football", got["hobby"])
	require.Equal(b, "milo,coke,pepsi", got["favouriteDrinks"])
	require.Empty(b, got["alloc"])
	require.Equal(b, "1", got["no"])
}
// Test_Ctx_BodyStreamWriter checks that a fasthttp body stream writer's output,
// including data written after a flush, appears in the serialized response.
//
// go test -run Test_Ctx_BodyStreamWriter
func Test_Ctx_BodyStreamWriter(t *testing.T) {
	t.Parallel()
	fctx := &fasthttp.RequestCtx{}
	fctx.SetBodyStreamWriter(func(w *bufio.Writer) {
		fmt.Fprintf(w, "body writer line 1\n")
		if err := w.Flush(); err != nil {
			t.Errorf("unexpected error: %s", err)
		}
		fmt.Fprintf(w, "body writer line 2\n")
	})
	require.True(t, fctx.IsBodyStream())

	// Serialize the response, then parse it back to read the streamed body.
	raw := bufio.NewReader(bytes.NewBufferString(fctx.Response.String()))
	var parsed fasthttp.Response
	require.NoError(t, parsed.Read(raw))

	require.Equal(t, "body writer line 1\nbody writer line 2\n", string(parsed.Body()))
}
// Benchmark_Ctx_BodyStreamWriter measures resetting the body and installing a
// stream writer on a fasthttp.RequestCtx.
//
// NOTE(review): the response is never serialized inside the loop, so the
// writer callback presumably never runs and writeErr stays nil. The final
// NoError check is then vacuous; confirm before relying on it.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_BodyStreamWriter -benchmem -count=4
func Benchmark_Ctx_BodyStreamWriter(b *testing.B) {
	fctx := &fasthttp.RequestCtx{}
	payload := []byte(`{"name":"john"}`)

	b.ReportAllocs()
	var writeErr error
	for b.Loop() {
		fctx.ResetBody()
		fctx.SetBodyStreamWriter(func(w *bufio.Writer) {
			for range 10 {
				_, writeErr = w.Write(payload)
				if flushErr := w.Flush(); flushErr != nil {
					return
				}
			}
		})
	}
	require.NoError(b, writeErr)
}
// Test_Ctx_String checks the String representation of a context built from an
// empty fasthttp.RequestCtx.
func Test_Ctx_String(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	const want = "#0000000000000000 - 0.0.0.0:0 <-> 0.0.0.0:0 - GET http:///"
	require.Equal(t, want, c.String())
}
// Benchmark_Ctx_String measures building a context's String representation.
//
// go test -v -run=^$ -bench=Benchmark_Ctx_String -benchmem -count=4
func Benchmark_Ctx_String(b *testing.B) {
	app := New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})

	var repr string
	b.ReportAllocs()
	for b.Loop() {
		repr = ctx.String()
	}
	require.Equal(b, "#0000000000000000 - 0.0.0.0:0 <-> 0.0.0.0:0 - GET http:///", repr)
}
// Test_Ctx_IsFromLocal_X_Forwarded checks that an X-Forwarded-For header never
// makes a request count as local, even when it names a loopback address.
// Every context acquired here is released when the test finishes.
//
// go test -run Test_Ctx_IsFromLocal_X_Forwarded
func Test_Ctx_IsFromLocal_X_Forwarded(t *testing.T) {
	t.Parallel()

	// Unset X-Forwarded-For header.
	{
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })
		// fasthttp returns "0.0.0.0" as IP as there is no remote address.
		require.Equal(t, "0.0.0.0", c.IP())
		require.False(t, c.IsFromLocal())
	}

	// Loopback addresses supplied only through X-Forwarded-For.
	for _, forwarded := range []string{"127.0.0.1", "::1", "0:0:0:0:0:0:0:1"} {
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })
		c.Request().Header.Set(HeaderXForwardedFor, forwarded)
		require.False(t, c.IsFromLocal(), "X-Forwarded-For: %s", forwarded)
	}

	// A random public IP address, checked through the Req() API.
	{
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })
		c.Request().Header.Set(HeaderXForwardedFor, "93.46.8.90")
		require.False(t, c.Req().IsFromLocal())
	}
}
// Test_Ctx_IsFromLocal_RemoteAddr checks IP and IsFromLocal for a range of
// fasthttp remote addresses: loopback (IPv4, short and long IPv6), unspecified,
// public IPv4/IPv6, and a Unix socket.
//
// go test -run Test_Ctx_IsFromLocal_RemoteAddr
func Test_Ctx_IsFromLocal_RemoteAddr(t *testing.T) {
	t.Parallel()

	cases := []struct {
		addr      net.Addr
		wantIP    string // expected IP(); empty means IP() is not checked
		wantLocal bool
		viaReq    bool // assert through c.Req() rather than c
	}{
		{addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1")}, wantIP: "127.0.0.1", wantLocal: true},
		{addr: &net.TCPAddr{IP: net.ParseIP("::1")}, wantIP: "::1", wantLocal: true, viaReq: true},
		// fasthttp should return "::1" for "0:0:0:0:0:0:0:1";
		// otherwise IsFromLocal() will break.
		{addr: &net.TCPAddr{IP: net.ParseIP("0:0:0:0:0:0:0:1")}, wantIP: "::1", wantLocal: true},
		{addr: &net.TCPAddr{IP: net.IPv4zero}, wantIP: "0.0.0.0"},
		{addr: &net.TCPAddr{IP: net.ParseIP("93.46.8.90")}, wantIP: "93.46.8.90"},
		{addr: &net.TCPAddr{IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}, wantIP: "2001:db8:85a3::8a2e:370:7334"},
		// Unix sockets are inherently local - only processes on the same host can connect.
		{addr: &net.UnixAddr{Name: "/tmp/fiber.sock", Net: "unix"}, wantLocal: true},
	}

	for _, tc := range cases {
		app := New()
		fastCtx := &fasthttp.RequestCtx{}
		fastCtx.SetRemoteAddr(tc.addr)
		c := app.AcquireCtx(fastCtx)
		t.Cleanup(func() { app.ReleaseCtx(c) })

		if tc.viaReq {
			require.Equal(t, tc.wantIP, c.Req().IP())
			require.Equal(t, tc.wantLocal, c.Req().IsFromLocal())
			continue
		}
		if tc.wantIP != "" {
			require.Equal(t, tc.wantIP, c.IP())
		}
		require.Equal(t, tc.wantLocal, c.IsFromLocal())
	}
}
// go test -run Test_Ctx_extractIPsFromHeader -v
//
// Test_Ctx_extractIPsFromHeader feeds a noisy X-Forwarded-For header to IPs()
// and checks that the second-to-last returned entry is 42.118.81.169.
func Test_Ctx_extractIPsFromHeader(t *testing.T) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	c.Request().Header.Set("x-forwarded-for", "1.1.1.1,8.8.8.8 , /n, \n,1.1, a.c, 6.,6., , a,,42.118.81.169,10.0.137.108")
	ips := c.IPs()

	// Guard the index below so a regression fails the test instead of panicking.
	require.GreaterOrEqual(t, len(ips), 2)
	require.Equal(t, "42.118.81.169", ips[len(ips)-2])
}
// go test -run Test_Ctx_extractIPsFromHeader_EnableValidateIp -v
//
// Test_Ctx_extractIPsFromHeader_EnableValidateIp repeats the noisy
// X-Forwarded-For scenario with EnableIPValidation turned on and checks that
// the second-to-last entry returned by IPs() is still 42.118.81.169.
func Test_Ctx_extractIPsFromHeader_EnableValidateIp(t *testing.T) {
	app := New(Config{EnableIPValidation: true})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	c.Request().Header.Set("x-forwarded-for", "1.1.1.1,8.8.8.8 , /n, \n,1.1, a.c, 6.,6., , a,,42.118.81.169,10.0.137.108")
	ips := c.IPs()

	// Guard the index below so a regression fails the test instead of panicking.
	require.GreaterOrEqual(t, len(ips), 2)
	require.Equal(t, "42.118.81.169", ips[len(ips)-2])
}
// go test -run Test_Ctx_GetRespHeaders
//
// Test_Ctx_GetRespHeaders checks that GetRespHeaders and Res().GetHeaders both
// return every response header under its canonical name, keeping all values
// of a repeated header in insertion order.
func Test_Ctx_GetRespHeaders(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	c.Set("test", "Hello, World 👋!")
	c.Set("foo", "bar")
	c.Response().Header.Set("multi", "one")
	c.Response().Header.Add("multi", "two")
	c.Response().Header.Set(HeaderContentType, "application/json")

	expected := map[string][]string{
		"Content-Type": {"application/json"},
		"Foo":          {"bar"},
		"Multi":        {"one", "two"},
		"Test":         {"Hello, World 👋!"},
	}
	require.Equal(t, expected, c.GetRespHeaders())
	require.Equal(t, expected, c.Res().GetHeaders())
}
// go test -v -run=^$ -bench=Benchmark_Ctx_GetRespHeaders -benchmem -count=4
//
// Benchmark_Ctx_GetRespHeaders measures building the response-header map and
// verifies the final result once the loop has finished.
func Benchmark_Ctx_GetRespHeaders(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	// Released only after the assertion below, since the result is read from this ctx.
	defer app.ReleaseCtx(c)

	c.Response().Header.Set("test", "Hello, World 👋!")
	c.Response().Header.Set("foo", "bar")
	c.Response().Header.Set(HeaderContentType, "application/json")

	b.ReportAllocs()
	var headers map[string][]string
	for b.Loop() {
		headers = c.GetRespHeaders()
	}

	require.Equal(b, map[string][]string{
		"Content-Type": {"application/json"},
		"Foo":          {"bar"},
		"Test":         {"Hello, World 👋!"},
	}, headers)
}
// go test -run Test_Ctx_GetReqHeaders
//
// Test_Ctx_GetReqHeaders checks that GetReqHeaders and GetHeaders both return
// every request header under its canonical name, keeping all values of a
// repeated header in insertion order.
func Test_Ctx_GetReqHeaders(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	c.Request().Header.Set("test", "Hello, World 👋!")
	c.Request().Header.Set("foo", "bar")
	c.Request().Header.Set("multi", "one")
	c.Request().Header.Add("multi", "two")
	c.Request().Header.Set(HeaderContentType, "application/json")

	expected := map[string][]string{
		"Content-Type": {"application/json"},
		"Foo":          {"bar"},
		"Test":         {"Hello, World 👋!"},
		"Multi":        {"one", "two"},
	}
	require.Equal(t, expected, c.GetReqHeaders())
	require.Equal(t, expected, c.GetHeaders())
}
// Test_Ctx_Set_SanitizeHeaderValue verifies that Set neutralizes CR/LF in a
// header value, so it cannot inject an extra header line. Here "\r\n" is
// expected to end up as a single space.
func Test_Ctx_Set_SanitizeHeaderValue(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	c.Set("X-Test", "foo\r\nbar: bad")

	headerVal := string(c.Response().Header.Peek("X-Test"))
	require.Equal(t, "foo bar: bad", headerVal)
}
// go test -v -run=^$ -bench=Benchmark_Ctx_GetReqHeaders -benchmem -count=4
//
// Benchmark_Ctx_GetReqHeaders measures building the request-header map and
// verifies the final result once the loop has finished.
func Benchmark_Ctx_GetReqHeaders(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	// Released only after the assertion below, since the result is read from this ctx.
	defer app.ReleaseCtx(c)

	c.Request().Header.Set("test", "Hello, World 👋!")
	c.Request().Header.Set("foo", "bar")
	c.Request().Header.Set(HeaderContentType, "application/json")

	b.ReportAllocs()
	var headers map[string][]string
	for b.Loop() {
		headers = c.GetReqHeaders()
	}

	require.Equal(b, map[string][]string{
		"Content-Type": {"application/json"},
		"Foo":          {"bar"},
		"Test":         {"Hello, World 👋!"},
	}, headers)
}
// go test -run Test_Ctx_Drop -v
//
// Test_Ctx_Drop verifies that Drop yields no response at all
// (ErrTestGotEmptyResponse). It also checks that a handler which simply
// returns nil still produces a 200 OK with an empty body.
func Test_Ctx_Drop(t *testing.T) {
	t.Parallel()
	app := New()

	// Handler that calls Drop
	app.Get("/block-me", func(c Ctx) error {
		return c.Drop()
	})

	// Handler that returns without writing anything
	app.Get("/no-response", func(_ Ctx) error {
		return nil
	})

	// Test the Drop method: no response is produced.
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", http.NoBody))
	require.ErrorIs(t, err, ErrTestGotEmptyResponse)
	require.Nil(t, resp)

	// Test the no-response handler; close its body when the test is done.
	okResp, err := app.Test(httptest.NewRequest(MethodGet, "/no-response", http.NoBody))
	require.NoError(t, err)
	require.NotNil(t, okResp)
	defer func() { require.NoError(t, okResp.Body.Close()) }()
	require.Equal(t, StatusOK, okResp.StatusCode)
	require.Equal(t, "0", okResp.Header.Get("Content-Length"))
}
// go test -run Test_Ctx_DropWithMiddleware -v
//
// Test_Ctx_DropWithMiddleware verifies that a connection dropped by the
// handler stays dropped, even though an outer middleware keeps touching the
// response after c.Next() returns.
func Test_Ctx_DropWithMiddleware(t *testing.T) {
	t.Parallel()
	app := New()

	// Middleware that runs the rest of the chain, then sets a response header.
	app.Use(func(c Ctx) error {
		err := c.Next()
		c.Set("X-Test", "test")
		return err
	})

	// Handler that calls Drop
	app.Get("/block-me", func(c Ctx) error {
		return c.Drop()
	})

	// The header set after Drop must not resurrect a response.
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", http.NoBody))
	require.ErrorIs(t, err, ErrTestGotEmptyResponse)
	require.Nil(t, resp)
}
// go test -run Test_Ctx_End
//
// Test_Ctx_End verifies that a body written before End reaches the client
// with a 200 OK.
func Test_Ctx_End(t *testing.T) {
	app := New()
	app.Get("/", func(c Ctx) error {
		if err := c.SendString("Hello, World!"); err != nil {
			return err
		}
		return c.End()
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.NotNil(t, resp)
	defer func() { require.NoError(t, resp.Body.Close()) }()
	require.Equal(t, StatusOK, resp.StatusCode)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "io.ReadAll(resp.Body)")
	require.Equal(t, "Hello, World!", string(body))
}
// go test -run Test_Ctx_End_after_timeout
//
// Test_Ctx_End_after_timeout: the handler sleeps for 2s before calling End,
// so app.Test (called without an explicit timeout) is expected to give up
// with os.ErrDeadlineExceeded and return no response.
func Test_Ctx_End_after_timeout(t *testing.T) {
	app := New()

	// Early flushing handler that only flushes after a long sleep.
	slowFlush := func(c Ctx) error {
		time.Sleep(2 * time.Second)
		return c.End()
	}
	app.Get("/", slowFlush)

	req := httptest.NewRequest(MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.ErrorIs(t, err, os.ErrDeadlineExceeded)
	require.Nil(t, resp)
}
// go test -run Test_Ctx_End_with_drop_middleware
//
// Test_Ctx_End_with_drop_middleware verifies that a response flushed with End
// still reaches the client as 200 OK, even if a middleware calls Drop once
// c.Next() returns.
func Test_Ctx_End_with_drop_middleware(t *testing.T) {
	app := New()

	// Middleware that will drop connections
	// that persist after c.Next()
	app.Use(func(c Ctx) error {
		c.Next() //nolint:errcheck // unnecessary to check error
		return c.Drop()
	})

	// Early flushing handler
	app.Get("/", func(c Ctx) error {
		if err := c.SendStatus(StatusOK); err != nil {
			return err
		}
		return c.End()
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.NotNil(t, resp)
	defer func() { require.NoError(t, resp.Body.Close()) }()
	require.Equal(t, StatusOK, resp.StatusCode)
}
// go test -run Test_Ctx_End_after_drop
//
// Test_Ctx_End_after_drop ensures that calling End in a middleware, after the
// handler already dropped the connection, does not produce a response.
func Test_Ctx_End_after_drop(t *testing.T) {
	app := New()

	// Middleware: run the rest of the chain, then try to end the request.
	endAfterNext := func(c Ctx) error {
		c.Next() //nolint:errcheck // unnecessary to check error
		return c.End()
	}
	// Handler: drop the connection without writing a response.
	dropHandler := func(c Ctx) error {
		return c.Drop()
	}
	app.Use(endAfterNext)
	app.Get("/", dropHandler)

	req := httptest.NewRequest(MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.ErrorIs(t, err, ErrTestGotEmptyResponse)
	require.Nil(t, resp)
}
// go test -run Test_Ctx_OverrideParam
//
// Test_Ctx_OverrideParam exercises OverrideParam against several cases:
// named params, "+" and "*" wildcard params (single and numbered), both
// CaseSensitive settings, and a context whose route is nil.
//
// NOTE(review): most subtests call require.* inside the route handler. If
// app.Test executes handlers on a goroutine other than the test goroutine, a
// failing require (FailNow) would not stop the test cleanly. Confirm this, and
// consider asserting on values captured in the test goroutine instead.
func Test_Ctx_OverrideParam(t *testing.T) {
	t.Parallel()
	t.Run("route_params", func(t *testing.T) {
		// a basic request to check if OverrideParam functions correctly on different scenarios
		// - Does it change an existing param (it should)
		// - Does it ignore a non-existing param (it should)
		t.Parallel()
		app := New()
		app.Get("/user/:name/:id", func(c Ctx) error {
			c.OverrideParam("name", "overridden")
			c.OverrideParam("nonexistent", "ignored")
			require.Equal(t, "overridden", c.Params("name"))
			require.Equal(t, "123", c.Params("id"))
			require.Empty(t, c.Params("nonexistent"))
			// Overriding values must leave the route's declared param names untouched.
			require.Equal(t, []string{"name", "id"}, c.Route().Params)
			return c.JSON(map[string]any{
				"name": c.Params("name"),
				"id":   c.Params("id"),
				"all":  c.Route().Params,
			})
		})
		req, err := http.NewRequest(http.MethodGet, "/user/original/123", http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		defer func() { require.NoError(t, resp.Body.Close()) }()
		require.Equal(t, StatusOK, resp.StatusCode)
	})
	t.Run("plus_wildcard_params", func(t *testing.T) {
		// "+" and "+2" address the first and second plus-wildcard segments.
		t.Parallel()
		app := New()
		app.Get("/files+/+",
			func(c Ctx) error {
				c.OverrideParam("+", "changed")
				c.OverrideParam("+2", "changed2")
				require.Equal(t, "changed", c.Params("+"))
				require.Equal(t, "changed2", c.Params("+2"))
				return nil
			},
		)
		req, err := http.NewRequest(http.MethodGet, "/filesoriginal/original2", http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		defer func() { require.NoError(t, resp.Body.Close()) }()
		require.Equal(t, StatusOK, resp.StatusCode)
	})
	t.Run("wildcard_params", func(t *testing.T) {
		// A single "*" wildcard can be overridden by its "*" key.
		t.Parallel()
		app := New()
		app.Get("/files/*", func(c Ctx) error {
			c.OverrideParam("*", "changed")
			require.Equal(t, "changed", c.Params("*"))
			return nil
		})
		req, err := http.NewRequest(http.MethodGet, "/files/testing", http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		defer func() { require.NoError(t, resp.Body.Close()) }()
		require.Equal(t, StatusOK, resp.StatusCode)
	})
	t.Run("multi_wildcard_params", func(t *testing.T) {
		// "*" and "*2" address the first and second star-wildcard segments.
		t.Parallel()
		app := New()
		app.Get("/files/*/*", func(c Ctx) error {
			c.OverrideParam("*", "changed")
			c.OverrideParam("*2", "changed2")
			require.Equal(t, "changed", c.Params("*"))
			require.Equal(t, "changed2", c.Params("*2"))
			return nil
		})
		req, err := http.NewRequest(http.MethodGet, "/files/testing/testing", http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		defer func() { require.NoError(t, resp.Body.Close()) }()
		require.Equal(t, StatusOK, resp.StatusCode)
	})
	t.Run("case_sensitive", func(t *testing.T) {
		t.Parallel()
		// Ensure OverrideParam respects the CaseSensitive configuration
		app := New(Config{
			CaseSensitive: true,
		})
		app.Get("/user/:name", func(c Ctx) error {
			c.OverrideParam("name", "overridden")
			require.Equal(t, "overridden", c.Params("name"))
			// With CaseSensitive, a differently-cased key must not match.
			require.Empty(t, c.Params("NAME"))
			return c.SendStatus(StatusOK)
		})
		req, err := http.NewRequest(http.MethodGet, "/user/original", http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		defer func() { require.NoError(t, resp.Body.Close()) }()
		require.Equal(t, StatusOK, resp.StatusCode)
	})
	t.Run("case_insensitive", func(t *testing.T) {
		t.Parallel()
		// CaseInsensitive mode (default)
		app := New(Config{
			CaseSensitive: false,
		})
		app.Get("/user/:name", func(c Ctx) error {
			// Overriding via an upper-case key must update the lower-case param too.
			c.OverrideParam("NAME", "overridden")
			require.Equal(t, "overridden", c.Params("name"))
			require.Equal(t, "overridden", c.Params("NAME"))
			return c.SendStatus(StatusOK)
		})
		req, err := http.NewRequest(http.MethodGet, "/user/original", http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		defer func() { require.NoError(t, resp.Body.Close()) }()
		require.Equal(t, StatusOK, resp.StatusCode)
	})
	t.Run("nil_router", func(t *testing.T) {
		t.Parallel()
		// Ensure OverrideParam handles nil route context gracefully
		app := New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		c, ok := ctx.(*DefaultCtx)
		require.True(t, ok)
		defer app.ReleaseCtx(c)
		c.route = nil
		c.OverrideParam("test", "value") // Should not change
		require.Empty(t, c.Params("test"))
	})
}
// Test_Ctx_AbandonSkipsReleaseCtx verifies that ReleaseCtx leaves an abandoned
// context untouched: the abandon flag, fasthttp reference and route survive.
func Test_Ctx_AbandonSkipsReleaseCtx(t *testing.T) {
	t.Parallel()
	app := New()
	ctx, ok := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx)
	require.True(t, ok, "AcquireCtx did not return *DefaultCtx")
	ctx.route = &Route{}
	t.Cleanup(func() {
		// ReleaseCtx does not pool an abandoned ctx, so hand it back explicitly.
		ctx.ForceRelease()
	})

	require.False(t, ctx.IsAbandoned())
	ctx.Abandon()
	require.True(t, ctx.IsAbandoned())

	app.ReleaseCtx(ctx)
	require.True(t, ctx.IsAbandoned(), "ReleaseCtx must not pool abandoned contexts")
	require.NotNil(t, ctx.fasthttp, "ReleaseCtx should not reset fasthttp on abandoned ctx")
	require.NotNil(t, ctx.route, "ReleaseCtx should not reset route on abandoned ctx")
}
// Test_Ctx_ForceReleaseClearsAbandon verifies that ForceRelease resets an
// abandoned context: it clears the abandon flag and drops the fasthttp
// reference and route.
func Test_Ctx_ForceReleaseClearsAbandon(t *testing.T) {
	t.Parallel()
	app := New()
	ctx, ok := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx)
	require.True(t, ok, "AcquireCtx did not return *DefaultCtx")
	ctx.route = &Route{}

	ctx.Abandon()
	ctx.ForceRelease()

	require.False(t, ctx.IsAbandoned(), "ForceRelease should clear abandon flag")
	require.Nil(t, ctx.fasthttp, "ForceRelease should release fasthttp reference")
	require.Nil(t, ctx.route, "ForceRelease should reset route before pooling")
}
// go test -v -run=^$ -bench=Benchmark_Ctx_IsProxyTrusted -benchmem -count=4
func Benchmark_Ctx_IsProxyTrusted(b *testing.B) {
// Scenario without trusted proxy check
b.Run("NoProxyCheck", func(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com:8080/test")
b.ReportAllocs()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario without trusted proxy check in parallel
b.Run("NoProxyCheckParallel", func(b *testing.B) {
app := New()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com:8080/test")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
// Scenario with trusted proxy check simple
b.Run("WithProxyCheckSimple", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
b.ReportAllocs()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario with trusted proxy check simple in parallel
b.Run("WithProxyCheckSimpleParallel", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
// Scenario with trusted proxy check
b.Run("WithProxyCheck", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{"0.0.0.0"},
},
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
b.ReportAllocs()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario with trusted proxy check in parallel
b.Run("WithProxyCheckParallel", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{"0.0.0.0"},
},
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
// Scenario with trusted proxy check allow private
b.Run("WithProxyCheckAllowPrivate", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Private: true,
},
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
b.ReportAllocs()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario with trusted proxy check allow private in parallel
b.Run("WithProxyCheckAllowPrivateParallel", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Private: true,
},
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
// Scenario with trusted proxy check allow private as subnets
b.Run("WithProxyCheckAllowPrivateAsSubnets", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "fc00::/7"},
},
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
b.ReportAllocs()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario with trusted proxy check allow private as subnets in parallel
b.Run("WithProxyCheckAllowPrivateAsSubnetsParallel", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "fc00::/7"},
},
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
// Scenario with trusted proxy check allow private, loopback, and link-local
b.Run("WithProxyCheckAllowAll", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Private: true,
Loopback: true,
LinkLocal: true,
},
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
b.ReportAllocs()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario with trusted proxy check allow private, loopback, and link-local in parallel
b.Run("WithProxyCheckAllowAllParallel", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Private: true,
Loopback: true,
LinkLocal: true,
},
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
// Scenario with trusted proxy check allow private, loopback, and link-local as subnets
b.Run("WithProxyCheckAllowAllowAllAsSubnets", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{
// Link-local
"169.254.0.0/16",
"fe80::/10",
// Loopback
"127.0.0.0/8",
"::1/128",
// Private
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"fc00::/7",
},
},
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario with trusted proxy check allow private, loopback, and link-local as subnets in parallel
b.Run("WithProxyCheckAllowAllowAllAsSubnetsParallel", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{
// Link-local
"169.254.0.0/16",
"fe80::/10",
// Loopback
"127.0.0.0/8",
"::1/128",
// Private
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"fc00::/7",
},
},
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
// Scenario with trusted proxy check with subnet
b.Run("WithProxyCheckSubnet", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{"0.0.0.0/8"},
},
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
b.ReportAllocs()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario with trusted proxy check with subnet in parallel
b.Run("WithProxyCheckParallelSubnet", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{"0.0.0.0/8"},
},
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
// Scenario with trusted proxy check with multiple subnet
b.Run("WithProxyCheckMultipleSubnet", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{"192.168.0.0/24", "10.0.0.0/16", "0.0.0.0/8"},
},
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
b.ReportAllocs()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario with trusted proxy check with multiple subnet in parallel
b.Run("WithProxyCheckParallelMultipleSubnet", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{"192.168.0.0/24", "10.0.0.0/16", "0.0.0.0/8"},
},
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
// Scenario with trusted proxy check with all subnets
b.Run("WithProxyCheckAllSubnets", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{
"127.0.0.0/8", // Loopback addresses
"169.254.0.0/16", // Link-Local addresses
"fe80::/10", // Link-Local addresses
"192.168.0.0/16", // Private Network addresses
"172.16.0.0/12", // Private Network addresses
"10.0.0.0/8", // Private Network addresses
"fc00::/7", // Unique Local addresses
"173.245.48.0/20", // My custom range
"0.0.0.0/8", // All IPv4 addresses
},
},
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/test")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
b.ReportAllocs()
for b.Loop() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
// Scenario with trusted proxy check with all subnets in parallel
b.Run("WithProxyCheckParallelAllSubnets", func(b *testing.B) {
app := New(Config{
TrustProxy: true,
TrustProxyConfig: TrustProxyConfig{
Proxies: []string{
"127.0.0.0/8", // Loopback addresses
"169.254.0.0/16", // Link-Local addresses
"fe80::/10", // Link-Local addresses
"192.168.0.0/16", // Private Network addresses
"172.16.0.0/12", // Private Network addresses
"10.0.0.0/8", // Private Network addresses
"fc00::/7", // Unique Local addresses
"173.245.48.0/20", // My custom range
"0.0.0.0/8", // All IPv4 addresses
},
},
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().SetRequestURI("http://google.com/")
c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
for pb.Next() {
c.IsProxyTrusted()
}
app.ReleaseCtx(c)
})
})
}
// go test -v -run=^$ -bench=Benchmark_Ctx_IsFromLocalhost -benchmem -count=4
//
// Benchmark_Ctx_IsFromLocalhost measures IsFromLocal for a non-local and a
// loopback peer. IsFromLocal is decided by the connection's remote address,
// not by the request Host. The "Localhost" scenario therefore sets a 127.0.0.1
// remote address explicitly, so it exercises the local path.
func Benchmark_Ctx_IsFromLocalhost(b *testing.B) {
	// Scenario with a non-local remote address
	b.Run("Non_Localhost", func(b *testing.B) {
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().SetRequestURI("http://google.com:8080/test")
		b.ReportAllocs()
		for b.Loop() {
			c.IsFromLocal()
		}
		app.ReleaseCtx(c)
	})
	// Scenario with a loopback remote address
	b.Run("Localhost", func(b *testing.B) {
		app := New()
		fastCtx := &fasthttp.RequestCtx{}
		fastCtx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1")})
		c := app.AcquireCtx(fastCtx)
		c.Request().SetRequestURI("http://localhost:8080/test")
		// Fail fast if the setup does not actually hit the local branch.
		if !c.IsFromLocal() {
			b.Fatal("expected a loopback remote address to be reported as local")
		}
		b.ReportAllocs()
		for b.Loop() {
			c.IsFromLocal()
		}
		app.ReleaseCtx(c)
	})
}
// go test -v -run=^$ -bench=Benchmark_Ctx_OverrideParam -benchmem -count=4
//
// Benchmark_Ctx_OverrideParam measures replacing an existing route param value
// on a context that has been primed to look like a matched route.
func Benchmark_Ctx_OverrideParam(b *testing.B) {
	app := New()
	c, isDefault := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx)
	if !isDefault {
		b.Fatal("AcquireCtx did not return *DefaultCtx")
	}
	defer app.ReleaseCtx(c)

	// Prime a matched route with params "name" and "id" and their current values.
	c.route = &Route{Params: []string{"name", "id"}}
	c.values = [maxParams]string{"original", "12345"}
	c.setMatched(true)

	b.ReportAllocs()
	b.ResetTimer()
	for b.Loop() {
		c.OverrideParam("name", "changed")
	}
}
================================================
FILE: docs/addon/_category_.json
================================================
{
"label": "\uD83D\uDD0C Addon",
"position": 5,
"collapsed": true,
"link": {
"type": "generated-index",
"description": "Addon is an additional useful package that can be used in Fiber."
}
}
================================================
FILE: docs/addon/retry.md
================================================
---
id: retry
---
# Retry Addon
The Retry addon for [Fiber](https://github.com/gofiber/fiber) retries failed network operations using exponential
backoff with jitter. It repeatedly invokes a function until it succeeds or the maximum number of attempts is
exhausted. Jitter at each step breaks client synchronization and helps avoid collisions. If all attempts fail, the
addon returns an error.
## Table of Contents
- [Signatures](#signatures)
- [Examples](#examples)
- [Default Config](#default-config)
- [Custom Config](#custom-config)
- [Config](#config)
- [Default Config Example](#default-config-example)
## Signatures
```go
func NewExponentialBackoff(config ...retry.Config) *retry.ExponentialBackoff
```
## Examples
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v3/addon/retry"
"github.com/gofiber/fiber/v3/client"
)
func main() {
expBackoff := retry.NewExponentialBackoff(retry.Config{})
// Local variables used inside Retry
var resp *client.Response
var err error
// Retry a network request and return an error to signal another attempt
err = expBackoff.Retry(func() error {
client := client.New()
resp, err = client.Get("https://gofiber.io")
if err != nil {
return fmt.Errorf("GET gofiber.io failed: %w", err)
}
if resp.StatusCode() != 200 {
return fmt.Errorf("GET gofiber.io did not return 200 OK")
}
return nil
})
// If all retries failed, panic
if err != nil {
panic(err)
}
fmt.Printf("GET gofiber.io succeeded with status code %d\n", resp.StatusCode())
}
```
## Default Config
```go
retry.NewExponentialBackoff()
```
## Custom Config
```go
retry.NewExponentialBackoff(retry.Config{
InitialInterval: 2 * time.Second,
MaxBackoffTime: 64 * time.Second,
Multiplier: 2.0,
MaxRetryCount: 15,
})
```
## Config
```go
// Config defines the config for addon.
type Config struct {
// InitialInterval defines the initial time interval for backoff algorithm.
//
// Optional. Default: 1 * time.Second
InitialInterval time.Duration
// MaxBackoffTime defines the maximum backoff duration. Once the backoff
// interval reaches this value, every remaining retry waits at most
// MaxBackoffTime (32 seconds by default).
//
// Optional. Default: 32 * time.Second
MaxBackoffTime time.Duration
// Multiplier defines the factor by which the backoff interval grows after each attempt.
//
// Optional. Default: 2.0
Multiplier float64
// MaxRetryCount defines maximum retry count for the backoff algorithm.
//
// Optional. Default: 10
MaxRetryCount int
}
```
## Default Config Example
```go
// DefaultConfig is the default config for retry.
var DefaultConfig = Config{
InitialInterval: 1 * time.Second,
MaxBackoffTime: 32 * time.Second,
Multiplier: 2.0,
MaxRetryCount: 10,
currentInterval: 1 * time.Second,
}
```
================================================
FILE: docs/api/_category_.json
================================================
{
"label": "\uD83D\uDEE0\uFE0F API",
"position": 3,
"link": {
"type": "generated-index",
"description": "API documentation for Fiber."
}
}
================================================
FILE: docs/api/app.md
================================================
---
id: app
title: 🚀 App
description: The `App` type represents your Fiber application.
sidebar_position: 2
---
import Reference from '@site/src/components/reference';
## Helpers
### GetString
Returns `s` unchanged when [`Immutable`](./fiber.md#immutable) is disabled or `s` resides in read-only memory. Otherwise, it returns a detached copy using `strings.Clone`.
```go title="Signature"
func (app *App) GetString(s string) string
```
### GetBytes
Returns `b` unchanged when [`Immutable`](./fiber.md#immutable) is disabled or `b` resides in read-only memory. Otherwise, it returns a detached copy.
```go title="Signature"
func (app *App) GetBytes(b []byte) []byte
```
### ReloadViews
Reloads the configured view engine on demand by calling its `Load` method. Use this helper in development workflows (e.g., file watchers or debug-only routes) to pick up template changes without restarting the server. Returns an error if no view engine is configured or reloading fails.
```go title="Signature"
func (app *App) ReloadViews() error
```
```go title="Example"
app := fiber.New(fiber.Config{Views: engine})
app.Get("/dev/reload", func(c fiber.Ctx) error {
if err := app.ReloadViews(); err != nil {
return err
}
return c.SendString("Templates reloaded")
})
```
## Routing
import RoutingHandler from './../partials/routing/handler.md';
### Route Handlers
### Mounting
Mount another Fiber instance with [`app.Use`](./app.md#use), similar to Express's [`router.use`](https://expressjs.com/en/api.html#router.use).
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
micro := fiber.New()
// Mount the micro app on the "/john" route
app.Use("/john", micro) // GET /john/doe -> 200 OK
micro.Get("/doe", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
log.Fatal(app.Listen(":3000"))
}
```
### MountPath
The `MountPath` method returns the path pattern on which a sub-app was mounted, including the mount paths of its parent apps.
```go title="Signature"
func (app *App) MountPath() string
```
```go title="Example"
package main
import (
"fmt"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
one := fiber.New()
two := fiber.New()
three := fiber.New()
two.Use("/three", three)
one.Use("/two", two)
app.Use("/one", one)
fmt.Println("Mount paths:")
fmt.Println("one.MountPath():", one.MountPath()) // "/one"
fmt.Println("two.MountPath():", two.MountPath()) // "/one/two"
fmt.Println("three.MountPath():", three.MountPath()) // "/one/two/three"
fmt.Println("app.MountPath():", app.MountPath()) // ""
}
```
:::caution
Mounting order is important for `MountPath`. To get mount paths properly, you should start mounting from the deepest app.
:::
### Group
You can group routes by creating a `*Group` struct.
```go title="Signature"
func (app *App) Group(prefix string, handlers ...any) Router
```
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
api := app.Group("/api", handler) // /api
v1 := api.Group("/v1", handler) // /api/v1
v1.Get("/list", handler) // /api/v1/list
v1.Get("/user", handler) // /api/v1/user
v2 := api.Group("/v2", handler) // /api/v2
v2.Get("/list", handler) // /api/v2/list
v2.Get("/user", handler) // /api/v2/user
log.Fatal(app.Listen(":3000"))
}
func handler(c fiber.Ctx) error {
return c.SendString("Handler response")
}
```
### RouteChain
Returns an instance of a single route, which you can then use to handle HTTP verbs with optional middleware.
Similar to [`Express`](https://expressjs.com/en/api.html#app.route).
```go title="Signature"
func (app *App) RouteChain(path string) Register
```
Click here to see the `Register` interface
```go
type Register interface {
All(handler any, handlers ...any) Register
Get(handler any, handlers ...any) Register
Head(handler any, handlers ...any) Register
Post(handler any, handlers ...any) Register
Put(handler any, handlers ...any) Register
Delete(handler any, handlers ...any) Register
Connect(handler any, handlers ...any) Register
Options(handler any, handlers ...any) Register
Trace(handler any, handlers ...any) Register
Patch(handler any, handlers ...any) Register
Add(methods []string, handler any, handlers ...any) Register
RouteChain(path string) Register
}
```
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// Use `RouteChain` as a chainable route declaration method
app.RouteChain("/test").Get(func(c fiber.Ctx) error {
return c.SendString("GET /test")
})
app.RouteChain("/events").All(func(c fiber.Ctx) error {
// Runs for all HTTP verbs first
// Think of it as route-specific middleware!
return c.Next()
}).
Get(func(c fiber.Ctx) error {
return c.SendString("GET /events")
}).
Post(func(c fiber.Ctx) error {
// Maybe add a new event...
return c.SendString("POST /events")
})
// Combine multiple routes
app.RouteChain("/reports").RouteChain("/daily").Get(func(c fiber.Ctx) error {
return c.SendString("GET /reports/daily")
})
// Use multiple methods
app.RouteChain("/api").Get(func(c fiber.Ctx) error {
return c.SendString("GET /api")
}).Post(func(c fiber.Ctx) error {
return c.SendString("POST /api")
})
log.Fatal(app.Listen(":3000"))
}
```
### Route
Defines routes with a common prefix inside the supplied function. Internally it uses [`Group`](#group) to create a sub-router and accepts an optional name prefix.
```go title="Signature"
func (app *App) Route(prefix string, fn func(router Router), name ...string) Router
```
```go title="Example"
app.Route("/test", func(api fiber.Router) {
api.Get("/foo", handler).Name("foo") // /test/foo (name: test.foo)
api.Get("/bar", handler).Name("bar") // /test/bar (name: test.bar)
}, "test.")
```
### HandlersCount
Returns the number of registered handlers.
```go title="Signature"
func (app *App) HandlersCount() uint32
```
### Stack
Returns the underlying router stack.
```go title="Signature"
func (app *App) Stack() [][]*Route
```
```go title="Example"
package main
import (
"encoding/json"
"fmt"
"log"
"github.com/gofiber/fiber/v3"
)
var handler = func(c fiber.Ctx) error { return nil }
func main() {
app := fiber.New()
app.Get("/john/:age", handler)
app.Post("/register", handler)
data, _ := json.MarshalIndent(app.Stack(), "", " ")
fmt.Println(string(data))
log.Fatal(app.Listen(":3000"))
}
```
Click here to see the result
```json
[
[
{
"method": "GET",
"path": "/john/:age",
"params": [
"age"
]
}
],
[
{
"method": "HEAD",
"path": "/john/:age",
"params": [
"age"
]
}
],
[
{
"method": "POST",
"path": "/register",
"params": null
}
]
]
```
### Name
This method assigns the name to the latest created route.
```go title="Signature"
func (app *App) Name(name string) Router
```
```go title="Example"
package main
import (
"encoding/json"
"fmt"
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
var handler = func(c fiber.Ctx) error { return nil }
app := fiber.New()
app.Get("/", handler)
app.Name("index")
app.Get("/doe", handler).Name("home")
app.Trace("/tracer", handler).Name("tracert")
app.Delete("/delete", handler).Name("delete")
a := app.Group("/a")
a.Name("fd.")
a.Get("/test", handler).Name("test")
data, _ := json.MarshalIndent(app.Stack(), "", " ")
fmt.Println(string(data))
log.Fatal(app.Listen(":3000"))
}
```
Click here to see the result
```json
[
[
{
"method": "GET",
"name": "index",
"path": "/",
"params": null
},
{
"method": "GET",
"name": "home",
"path": "/doe",
"params": null
},
{
"method": "GET",
"name": "fd.test",
"path": "/a/test",
"params": null
}
],
[
{
"method": "HEAD",
"name": "",
"path": "/",
"params": null
},
{
"method": "HEAD",
"name": "",
"path": "/doe",
"params": null
},
{
"method": "HEAD",
"name": "",
"path": "/a/test",
"params": null
}
],
null,
null,
[
{
"method": "DELETE",
"name": "delete",
"path": "/delete",
"params": null
}
],
null,
null,
[
{
"method": "TRACE",
"name": "tracert",
"path": "/tracer",
"params": null
}
],
null
]
```
### GetRoute
This method retrieves a route by its name.
```go title="Signature"
func (app *App) GetRoute(name string) Route
```
```go title="Example"
package main
import (
"encoding/json"
"fmt"
"log"
"github.com/gofiber/fiber/v3"
)
var handler = func(c fiber.Ctx) error { return nil }
func main() {
app := fiber.New()
app.Get("/", handler).Name("index")
route := app.GetRoute("index")
data, _ := json.MarshalIndent(route, "", " ")
fmt.Println(string(data))
log.Fatal(app.Listen(":3000"))
}
```
Click here to see the result
```json
{
"method": "GET",
"name": "index",
"path": "/",
"params": null
}
```
### GetRoutes
This method retrieves all routes.
```go title="Signature"
func (app *App) GetRoutes(filterUseOption ...bool) []Route
```
When `filterUseOption` is set to `true`, it filters out routes registered by middleware.
```go title="Example"
package main
import (
"encoding/json"
"fmt"
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Post("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
}).Name("index")
routes := app.GetRoutes(true)
data, _ := json.MarshalIndent(routes, "", " ")
fmt.Println(string(data))
log.Fatal(app.Listen(":3000"))
}
```
Click here to see the result
```json
[
{
"method": "POST",
"name": "index",
"path": "/",
"params": null
}
]
```
## Config
`Config` returns the [app config](./fiber.md#config) as a value (read-only).
```go title="Signature"
func (app *App) Config() Config
```
## Handler
`Handler` returns the server handler that can be used to serve custom [`\*fasthttp.RequestCtx`](https://pkg.go.dev/github.com/valyala/fasthttp#RequestCtx) requests.
```go title="Signature"
func (app *App) Handler() fasthttp.RequestHandler
```
## ErrorHandler
`ErrorHandler` invokes the error handler configured for the application with the given context and error. Some middleware uses it to handle errors directly.
```go title="Signature"
func (app *App) ErrorHandler(ctx Ctx, err error) error
```
## NewWithCustomCtx
`NewWithCustomCtx` creates a new `*App` and sets the custom context factory
function at construction time.
```go title="Signature"
func NewWithCustomCtx(fn func(app *App) CustomCtx, config ...Config) *App
```
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
type CustomCtx struct {
fiber.DefaultCtx
}
func (c *CustomCtx) Params(key string, defaultValue ...string) string {
return "prefix_" + c.DefaultCtx.Params(key, defaultValue...)
}
func main() {
app := fiber.NewWithCustomCtx(func(app *fiber.App) fiber.CustomCtx {
return &CustomCtx{
DefaultCtx: *fiber.NewDefaultCtx(app),
}
})
app.Get("/:id", func(c fiber.Ctx) error {
return c.SendString(c.Params("id"))
})
log.Fatal(app.Listen(":3000"))
}
```
## RegisterCustomBinder
You can register custom binders to use with [`Bind().Custom("name")`](bind.md#custom). They must implement the `CustomBinder` interface.
```go title="Signature"
func (app *App) RegisterCustomBinder(binder CustomBinder)
```
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"gopkg.in/yaml.v2"
)
type User struct {
Name string `yaml:"name"`
}
type customBinder struct{}
func (*customBinder) Name() string {
return "custom"
}
func (*customBinder) MIMETypes() []string {
return []string{"application/yaml"}
}
func (*customBinder) Parse(c fiber.Ctx, out any) error {
// Parse YAML body
return yaml.Unmarshal(c.Body(), out)
}
func main() {
app := fiber.New()
// Register custom binder
app.RegisterCustomBinder(&customBinder{})
app.Post("/custom", func(c fiber.Ctx) error {
var user User
// Use Custom binder by name
if err := c.Bind().Custom("custom", &user); err != nil {
return err
}
return c.JSON(user)
})
app.Post("/normal", func(c fiber.Ctx) error {
var user User
// Custom binder is used by the MIME type
if err := c.Bind().Body(&user); err != nil {
return err
}
return c.JSON(user)
})
log.Fatal(app.Listen(":3000"))
}
```
## RegisterCustomConstraint
`RegisterCustomConstraint` allows you to register custom constraints.
```go title="Signature"
func (app *App) RegisterCustomConstraint(constraint CustomConstraint)
```
See the [Custom Constraint](../guide/routing.md#custom-constraint) section for more information.
## SetTLSHandler
Use `SetTLSHandler` to set [`ClientHelloInfo`](https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.2) when using TLS with a `Listener`.
```go title="Signature"
func (app *App) SetTLSHandler(tlsHandler *TLSHandler)
```
## Test
Testing your application is done with the `Test` method. Use this method for creating `_test.go` files or when you need to debug your routing logic. The default timeout is `1s`; to disable a timeout altogether, pass a `TestConfig` struct with `Timeout: 0`.
```go title="Signature"
func (app *App) Test(req *http.Request, config ...TestConfig) (*http.Response, error)
```
```go title="Example"
package main
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// Create route with GET method for test:
app.Get("/", func(c fiber.Ctx) error {
fmt.Println(c.BaseURL()) // => http://google.com
fmt.Println(c.Get("X-Custom-Header")) // => hi
return c.SendString("hello, World!")
})
// Create http.Request
req := httptest.NewRequest("GET", "http://google.com", nil)
req.Header.Set("X-Custom-Header", "hi")
// Perform the test
resp, _ := app.Test(req)
// Do something with the results:
if resp.StatusCode == fiber.StatusOK {
body, _ := io.ReadAll(resp.Body)
fmt.Println(string(body)) // => hello, World!
}
}
```
If not provided, TestConfig is set to the following defaults:
```go title="Default TestConfig"
config := fiber.TestConfig{
Timeout: time.Second,
FailOnTimeout: true,
}
```
:::caution
This is **not** the same as supplying an empty `TestConfig{}` to
`app.Test()`. An empty config is equivalent to supplying:
```go title="Empty TestConfig"
cfg := fiber.TestConfig{
Timeout: 0,
FailOnTimeout: false,
}
```
This would make a Test that has no timeout.
:::
## Hooks
`Hooks` is a method to return the [hooks](./hooks.md) property.
```go title="Signature"
func (app *App) Hooks() *Hooks
```
## RebuildTree
The `RebuildTree` method is designed to rebuild the route tree and enable dynamic route registration. It returns a pointer to the `App` instance.
```go title="Signature"
func (app *App) RebuildTree() *App
```
**Note:** Use this method with caution. It is **not** thread-safe and calling it can be very performance-intensive, so it should be used sparingly and only in development mode. Avoid using it concurrently.
### Example Usage
Here’s an example of how to define and register routes dynamically:
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/define", func(c fiber.Ctx) error {
// Define a new route dynamically
app.Get("/dynamically-defined", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
// Rebuild the route tree to register the new route
app.RebuildTree()
return c.SendStatus(fiber.StatusOK)
})
log.Fatal(app.Listen(":3000"))
}
```
In this example, a new route is defined and then `RebuildTree()` is called to ensure the new route is registered and available.
## RemoveRoute
This method removes a route by path. You must call the `RebuildTree()` method after the removal to finalize the update and rebuild the routing tree.
If no methods are specified, the route will be removed for all HTTP methods defined in the app. To limit removal to specific methods, provide them as additional arguments.
```go title="Signature"
func (app *App) RemoveRoute(path string, methods ...string)
```
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/api/feature-a", func(c fiber.Ctx) error {
app.RemoveRoute("/api/feature", fiber.MethodGet)
app.RebuildTree()
// Redefine route
app.Get("/api/feature", func(c fiber.Ctx) error {
return c.SendString("Testing feature-a")
})
app.RebuildTree()
return c.SendStatus(fiber.StatusOK)
})
app.Get("/api/feature-b", func(c fiber.Ctx) error {
app.RemoveRoute("/api/feature", fiber.MethodGet)
app.RebuildTree()
// Redefine route
app.Get("/api/feature", func(c fiber.Ctx) error {
return c.SendString("Testing feature-b")
})
app.RebuildTree()
return c.SendStatus(fiber.StatusOK)
})
log.Fatal(app.Listen(":3000"))
}
```
## RemoveRouteByName
This method removes a route by name.
If no methods are specified, the route will be removed for all HTTP methods defined in the app. To limit removal to specific methods, provide them as additional arguments.
```go title="Signature"
func (app *App) RemoveRouteByName(name string, methods ...string)
```
## RemoveRouteFunc
This method removes every route for which `matchFunc`, which receives the route as a `*Route`, returns `true`.
If no methods are specified, the route will be removed for all HTTP methods defined in the app. To limit removal to specific methods, provide them as additional arguments.
```go title="Signature"
func (app *App) RemoveRouteFunc(matchFunc func(r *Route) bool, methods ...string)
```
================================================
FILE: docs/api/bind.md
================================================
---
id: bind
title: 📎 Bind
description: Binds the request and response items to a struct.
sidebar_position: 4
toc_max_heading_level: 4
---
Bindings parse request and response bodies, query parameters, cookies, and more into structs.
:::info
Binder-returned values are valid only within the handler. To keep them, copy the data
or enable the [**`Immutable`**](./ctx.md) setting. [Read more...](../#zero-allocation)
:::
## Binders
- [All](#all)
- [Body](#body)
- [CBOR](#cbor)
- [Form](#form)
- [JSON](#json)
- [MsgPack](#msgpack)
- [XML](#xml)
- [Cookie](#cookie)
- [Header](#header)
- [Query](#query)
- [RespHeader](#respheader)
- [URI](#uri)
### All
The `All` function binds data from URL parameters, the request body, query parameters, headers, and cookies into `out`. Sources are applied in the following order using struct field tags.
#### Precedence Order
The binding sources have the following precedence:
1. **URL Parameters (URI)**
2. **Request Body (e.g., JSON or form data)**
3. **Query Parameters**
4. **Request Headers**
5. **Cookies**
:::info
The request body is only included as a binding source when the request has both a non-empty body **and** a non-empty `Content-Type` header.
:::
```go title="Signature"
func (b *Bind) All(out any) error
```
```go title="Example"
type User struct {
Name string `query:"name" json:"name" form:"name"`
Email string `json:"email" form:"email"`
Role string `header:"X-User-Role"`
SessionID string `json:"session_id" cookie:"session_id"`
ID int `uri:"id" query:"id" json:"id" form:"id"`
}
app.Post("/users", func(c fiber.Ctx) error {
user := new(User)
if err := c.Bind().All(user); err != nil {
return err
}
// All available data is now bound to the user struct
return c.JSON(user)
})
```
### Body
Binds the request body to a struct.
Use tags that match the content type. For example, to parse a JSON body with a `Pass` field, declare `json:"pass"`.
| Content-Type | Struct Tag |
| ----------------------------------- | ---------- |
| `application/x-www-form-urlencoded` | `form` |
| `multipart/form-data` | `form` |
| `application/json` | `json` |
| `application/xml` | `xml` |
| `text/xml` | `xml` |
| `application/vnd.msgpack` | `msgpack` |
```go title="Signature"
func (b *Bind) Body(out any) error
```
```go title="Example"
type Person struct {
Name string `json:"name" xml:"name" form:"name" msgpack:"name"`
Pass string `json:"pass" xml:"pass" form:"pass" msgpack:"pass"`
}
app.Post("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().Body(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
// ...
})
```
Test the handler with these `curl` commands:
```bash
# JSON
curl -X POST -H "Content-Type: application/json" --data "{\"name\":\"john\",\"pass\":\"doe\"}" localhost:3000
# MsgPack
curl -X POST -H "Content-Type: application/vnd.msgpack" --data-binary $'\x82\xa4name\xa4john\xa4pass\xa3doe' localhost:3000
# XML
curl -X POST -H "Content-Type: application/xml" --data "<person><name>john</name><pass>doe</pass></person>" localhost:3000
# Form URL-Encoded
curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=john&pass=doe" localhost:3000
# Multipart Form
curl -X POST -F name=john -F pass=doe http://localhost:3000
```
### CBOR
> **Note:** Before using any CBOR-related features, make sure to follow the [CBOR setup instructions](../guide/advance-format.md#cbor).
Binds the request CBOR body to a struct.
It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a CBOR body with a field called `Pass`, you would use a struct field with `cbor:"pass"`.
```go title="Signature"
func (b *Bind) CBOR(out any) error
```
```go title="Example"
// Field names should start with an uppercase letter
type Person struct {
Name string `cbor:"name"`
Pass string `cbor:"pass"`
}
app.Post("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().CBOR(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
// ...
})
```
Run tests with the following `curl` command:
```bash
curl -X POST -H "Content-Type: application/cbor" --data-binary $'\xa2dnamedjohndpasscdoe' localhost:3000
```
### Form
Binds the request or multipart form body data to a struct.
It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a form body with a field called `Pass`, you would use a struct field with `form:"pass"`.
```go title="Signature"
func (b *Bind) Form(out any) error
```
```go title="Example"
type Person struct {
Name string `form:"name"`
Pass string `form:"pass"`
}
app.Post("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().Form(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
// ...
})
```
Run tests with the following `curl` commands for both `application/x-www-form-urlencoded` and `multipart/form-data`:
```bash
curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=john&pass=doe" localhost:3000
```
```bash
curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" localhost:3000
```
:::info
To bind multipart files, use `*multipart.FileHeader`, `*[]*multipart.FileHeader`, or `[]*multipart.FileHeader` as the field type.
:::
```go title="Example"
type Person struct {
Name string `form:"name"`
Pass string `form:"pass"`
Avatar *multipart.FileHeader `form:"avatar"`
}
app.Post("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().Form(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
log.Println(p.Avatar.Filename) // file.txt
// ...
})
```
Run tests with the following `curl` command:
```bash
curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" -F 'avatar=@filename' localhost:3000
```
### JSON
Binds the request JSON body to a struct.
It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a JSON body with a field called `Pass`, you would use a struct field with `json:"pass"`.
```go title="Signature"
func (b *Bind) JSON(out any) error
```
```go title="Example"
type Person struct {
Name string `json:"name"`
Pass string `json:"pass"`
}
app.Post("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().JSON(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
// ...
})
```
Run tests with the following `curl` command:
```bash
curl -X POST -H "Content-Type: application/json" --data "{\"name\":\"john\",\"pass\":\"doe\"}" localhost:3000
```
### MsgPack
> **Note:** Before using any MsgPack-related features, make sure to follow the [MsgPack setup instructions](../guide/advance-format.md#msgpack).
Binds the request MsgPack body to a struct.
It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a MsgPack body with a field called `Pass`, you would use a struct field with `msgpack:"pass"`.
> Our library uses [shamaton-msgpack](https://github.com/shamaton/msgpack) which uses `msgpack` struct tags by default. If you want to use other libraries, you may need to update the struct tags accordingly.
```go title="Signature"
func (b *Bind) MsgPack(out any) error
```
```go title="Example"
type Person struct {
Name string `msgpack:"name"`
Pass string `msgpack:"pass"`
}
app.Post("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().MsgPack(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
// ...
})
```
Run tests with the following `curl` command:
```bash
curl -X POST -H "Content-Type: application/vnd.msgpack" --data-binary $'\x82\xa4name\xa4john\xa4pass\xa3doe' localhost:3000
```
### XML
Binds the request XML body to a struct.
It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse an XML body with a field called `Pass`, you would use a struct field with `xml:"pass"`.
```go title="Signature"
func (b *Bind) XML(out any) error
```
```go title="Example"
// Field names should start with an uppercase letter
type Person struct {
Name string `xml:"name"`
Pass string `xml:"pass"`
}
app.Post("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().XML(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
// ...
})
```
Run tests with the following `curl` command:
```bash
curl -X POST -H "Content-Type: application/xml" --data "<person><name>john</name><pass>doe</pass></person>" localhost:3000
```
### Cookie
This method is similar to [Body Binding](#body), but for cookie parameters.
It is important to use the struct tag `cookie`. For example, if you want to parse a cookie with a field called `Age`, you would use a struct field with `cookie:"age"`.
```go title="Signature"
func (b *Bind) Cookie(out any) error
```
```go title="Example"
type Person struct {
Name string `cookie:"name"`
Age int `cookie:"age"`
Job bool `cookie:"job"`
}
app.Get("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().Cookie(p); err != nil {
return err
}
log.Println(p.Name) // Joseph
log.Println(p.Age) // 23
log.Println(p.Job) // true
})
```
Run tests with the following `curl` command:
```bash
curl --cookie "name=Joseph; age=23; job=true" http://localhost:3000/
```
### Header
This method is similar to [Body Binding](#body), but for request headers.
It is important to use the struct tag `header`. For example, if you want to parse a request header with a field called `Pass`, you would use a struct field with `header:"pass"`.
```go title="Signature"
func (b *Bind) Header(out any) error
```
```go title="Example"
type Person struct {
Name string `header:"name"`
Pass string `header:"pass"`
Products []string `header:"products"`
}
app.Get("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().Header(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
log.Println(p.Products) // [shoe hat]
// ...
})
```
Run tests with the following `curl` command:
```bash
curl "http://localhost:3000/" -H "name: john" -H "pass: doe" -H "products: shoe,hat"
```
### Query
This method is similar to [Body Binding](#body), but for query parameters.
It is important to use the struct tag `query`. For example, if you want to parse a query parameter with a field called `Pass`, you would use a struct field with `query:"pass"`.
```go title="Signature"
func (b *Bind) Query(out any) error
```
```go title="Example"
type Person struct {
Name string `query:"name"`
Pass string `query:"pass"`
Products []string `query:"products"`
}
app.Get("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().Query(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
// Depending on fiber.Config{EnableSplittingOnParsers: false} - default
log.Println(p.Products) // ["shoe,hat"]
// With fiber.Config{EnableSplittingOnParsers: true}
// log.Println(p.Products) // ["shoe", "hat"]
// ...
})
```
Run tests with the following `curl` command:
```bash
curl "http://localhost:3000/?name=john&pass=doe&products=shoe,hat"
```
:::info
For more parser settings, please refer to [Config](fiber.md#enablesplittingonparsers)
:::
#### Array Query Parameters
Fiber supports several formats for passing array values via query parameters. The following table gives an overview:
| Format | Example | Requires `EnableSplittingOnParsers` |
| ------------------------ | ---------------------------------------------- | ----------------------------------- |
| Repeated key | `?colors=red&colors=blue` | No |
| Bracket notation | `?colors[]=red&colors[]=blue` | No |
| Comma-separated | `?colors=red,blue` | **Yes** |
| Indexed bracket notation | `?posts[0][title]=Hello&posts[1][title]=World` | No |
| Nested bracket notation | `?preferences[tags]=golang,api` | No (comma splitting: **Yes**) |
##### Repeated Key
The most common approach. Repeat the same query key for each value:
```text
GET /api?colors=red&colors=blue&colors=green
```
```go title="Struct"
type Filter struct {
Colors []string `query:"colors"`
}
// Result: Colors = ["red", "blue", "green"]
```
```bash title="curl"
curl "http://localhost:3000/api?colors=red&colors=blue&colors=green"
```
##### Bracket Notation
Append `[]` to the key name. This is common in PHP-style and JavaScript frameworks:
```text
GET /api?colors[]=red&colors[]=blue&colors[]=green
```
```go title="Struct"
type Filter struct {
Colors []string `query:"colors"`
}
// Result: Colors = ["red", "blue", "green"]
```
```bash title="curl"
curl "http://localhost:3000/api?colors[]=red&colors[]=blue&colors[]=green"
```
:::note
The struct field tag stays `query:"colors"` (without brackets). Fiber strips the `[]` automatically.
:::
##### Comma-Separated Values
Pass multiple values in a single parameter, separated by commas. This format requires [`EnableSplittingOnParsers`](fiber.md#enablesplittingonparsers) to be set to `true`.
```text
GET /api?colors=red,blue,green
```
```go title="Struct"
type Filter struct {
Colors []string `query:"colors"`
}
```
```go title="App Setup"
// EnableSplittingOnParsers is required for comma splitting
app := fiber.New(fiber.Config{
EnableSplittingOnParsers: true,
})
// Result: Colors = ["red", "blue", "green"]
```
Without `EnableSplittingOnParsers`, the entire string `"red,blue,green"` is treated as a **single** element.
```go title="Default behavior (EnableSplittingOnParsers: false)"
// GET /api?colors=red,blue,green
// Result: Colors = ["red,blue,green"] ← single element
```
```bash title="curl"
curl "http://localhost:3000/api?colors=red,blue,green"
```
You can also mix comma-separated values with repeated keys when splitting is enabled:
```text
GET /api?hobby=soccer&hobby=basketball,football
```
```go
type Query struct {
Hobby []string `query:"hobby"`
}
// With EnableSplittingOnParsers: true
// Result: Hobby = ["soccer", "basketball", "football"] ← 3 elements
```
##### Indexed Bracket Notation (Nested Structs)
Use indexed brackets to bind arrays of nested structs:
```text
GET /api?posts[0][title]=Hello&posts[0][author]=Alice&posts[1][title]=World&posts[1][author]=Bob
```
```go title="Struct"
type Post struct {
Title string `query:"title"`
Author string `query:"author"`
}
type Request struct {
Posts []Post `query:"posts"`
}
// Result: Posts = [{Title: "Hello", Author: "Alice"}, {Title: "World", Author: "Bob"}]
```
```bash title="curl"
curl "http://localhost:3000/api?posts[0][title]=Hello&posts[0][author]=Alice&posts[1][title]=World&posts[1][author]=Bob"
```
##### Nested Bracket Notation (Without Index)
Use bracket notation to access fields of a nested struct:
```text
GET /api?preferences[tags]=golang,api
```
```go title="Struct"
type Preferences struct {
Tags *[]string `query:"tags"`
}
type Profile struct {
Prefs *Preferences `query:"preferences"`
}
// With EnableSplittingOnParsers: true
// Result: *Prefs.Tags = ["golang", "api"]
```
```bash title="curl"
curl "http://localhost:3000/api?preferences[tags]=golang,api"
```
:::note
Pointer fields (`*[]string`, `*Preferences`) let you distinguish between a missing parameter (`nil`) and an empty one. When the parameter is present, Fiber allocates the pointer automatically.
:::
### RespHeader
This method is similar to [Body Binding](#body), but for response headers.
It is important to use the struct tag `respHeader`. For example, if you want to parse a response header with a field called `Pass`, you would use a struct field with `respHeader:"pass"`.
```go title="Signature"
func (b *Bind) RespHeader(out any) error
```
```go title="Example"
type Person struct {
Name string `respHeader:"name"`
Pass string `respHeader:"pass"`
Products []string `respHeader:"products"`
}
app.Get("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().RespHeader(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
log.Println(p.Products) // [shoe hat]
// ...
})
```
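`RespHeader` reads headers that are already set on the **response**, for example by an earlier handler or middleware via `c.Set`. It does not read the incoming request headers. Sending request headers with `curl -H` therefore does not populate these fields. Set the response headers before calling `Bind().RespHeader`, then call the endpoint:
```bash
curl "http://localhost:3000/"
```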
### URI
This method is similar to [Body Binding](#body), but for path parameters.
It is important to use the struct tag `uri`. For example, if you want to parse a path parameter with a field called `Pass`, you would use a struct field with `uri:"pass"`.
```go title="Signature"
func (b *Bind) URI(out any) error
```
```go title="Example"
// GET http://example.com/user/111
app.Get("/user/:id", func(c fiber.Ctx) error {
param := struct {
ID uint `uri:"id"`
}{}
if err := c.Bind().URI(&param); err != nil {
return err
}
// ...
return c.SendString(fmt.Sprintf("User ID: %d", param.ID))
})
```
## BindError
When a bind method fails to parse (e.g. invalid JSON, bad type conversion), the behavior depends on the error-handling mode. In **manual handling** (the default), the binder returns a `*BindError` wrapping the underlying error — use `errors.As` to extract it and branch on the binding source or field. In **automatic handling** (enabled via `WithAutoHandling`), parse failures are instead converted to a `*fiber.Error` with HTTP status 400; `*BindError` is never surfaced to the caller in that mode. If you are using `WithAutoHandling`, check for `*fiber.Error` or an HTTP 400 response rather than using `errors.As` for `*BindError`.
```go
type BindError struct {
Source string // "uri", "query", "body", "header", "cookie", or "respHeader"
Field string // struct field or tag key that failed (best-effort, may be empty)
Err error // underlying error; use errors.As to inspect
}
```
Source constants: `BindSourceURI`, `BindSourceQuery`, `BindSourceHeader`, `BindSourceCookie`, `BindSourceBody`, `BindSourceRespHeader`.
### Branching on source
Use `errors.As` to extract `*BindError` and branch on `Source` for RFC-correct status codes (e.g. 404 for URI failures vs 400 for body/query):
```go title="Example"
// With manual handling mode (default behavior)
// Will not work with WithAutoHandling()
var req struct {
ID int `uri:"id"`
Name string `json:"name"`
}
if err := c.Bind().All(&req); err != nil {
var be *fiber.BindError
if errors.As(err, &be) && be.Source == fiber.BindSourceURI {
return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "not found"})
}
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request"})
}
```
### Validation vs binding errors
Validation errors (from `StructValidator`) are **not** wrapped in `BindError`. Use `errors.As(err, &be)` to distinguish: it succeeds only for parsing/binding failures, not for validation failures.
## Custom
To use custom binders, you have to use this method.
You can register them using the [RegisterCustomBinder](./app.md#registercustombinder) method of the Fiber instance.
```go title="Signature"
func (b *Bind) Custom(name string, dest any) error
```
```go title="Example"
app := fiber.New()
// My custom binder
type customBinder struct{}
func (cb *customBinder) Name() string {
return "custom"
}
func (cb *customBinder) MIMETypes() []string {
return []string{"application/yaml"}
}
func (cb *customBinder) Parse(c fiber.Ctx, out any) error {
// parse YAML body
return yaml.Unmarshal(c.Body(), out)
}
// Register custom binder
app.RegisterCustomBinder(&customBinder{})
type User struct {
Name string `yaml:"name"`
}
// curl -X POST http://localhost:3000/custom -H "Content-Type: application/yaml" -d "name: John"
app.Post("/custom", func(c fiber.Ctx) error {
var user User
// Use Custom binder by name
if err := c.Bind().Custom("custom", &user); err != nil {
return err
}
return c.JSON(user)
})
```
Internally, custom binders are also used in the [Body](#body) method.
The `MIMETypes` method is used to check if the custom binder should be used for the given content type.
## Options
For more control over error handling, you can use the following methods.
### WithAutoHandling
If you want to handle binder errors automatically, you can use `WithAutoHandling`.
If binding fails, it returns the error and sets the HTTP status to `400 Bad Request`.
This function does NOT panic, so you must still check for the error and return it explicitly.
```go title="Signature"
func (b *Bind) WithAutoHandling() *Bind
```
### WithoutAutoHandling
To handle binder errors manually, you can use the `WithoutAutoHandling` method.
It's the default behavior of the binder.
```go title="Signature"
func (b *Bind) WithoutAutoHandling() *Bind
```
### SkipValidation
To enable or disable validation for the current bind chain, use `SkipValidation`.
By default, validation is enabled (`skip = false`).
```go title="Signature"
func (b *Bind) SkipValidation(skip bool) *Bind
```
## SetParserDecoder
Configures the decoder used by the binders (for example, form body and query binding) based on schema options, making it possible to add custom types for parsing.
```go title="Signature"
func SetParserDecoder(parserConfig binder.ParserConfig)
```
`binder.ParserConfig` has the following fields:
```go
type ParserConfig struct {
IgnoreUnknownKeys bool
ParserType []ParserType
ZeroEmpty bool
SetAliasTag string
}
type ParserType struct {
CustomType any
Converter func(string) reflect.Value
}
```
```go title="Example"
type CustomTime time.Time
// String returns the time in string format
func (ct *CustomTime) String() string {
t := time.Time(*ct).String()
return t
}
// Converter for CustomTime type with format "2006-01-02"
var timeConverter = func(value string) reflect.Value {
fmt.Println("timeConverter:", value)
if v, err := time.Parse("2006-01-02", value); err == nil {
return reflect.ValueOf(CustomTime(v))
}
return reflect.Value{}
}
customTime := binder.ParserType{
CustomType: CustomTime{},
Converter: timeConverter,
}
// Add custom type to the Decoder settings
binder.SetParserDecoder(binder.ParserConfig{
IgnoreUnknownKeys: true,
ParserType: []binder.ParserType{customTime},
ZeroEmpty: true,
})
// Example using CustomTime with non-RFC3339 format
type Demo struct {
Date CustomTime `form:"date" query:"date"`
Title string `form:"title" query:"title"`
Body string `form:"body" query:"body"`
}
app.Post("/body", func(c fiber.Ctx) error {
var d Demo
if err := c.Bind().Body(&d); err != nil {
return err
}
fmt.Println("d.Date:", d.Date.String())
return c.JSON(d)
})
app.Get("/query", func(c fiber.Ctx) error {
var d Demo
if err := c.Bind().Query(&d); err != nil {
return err
}
fmt.Println("d.Date:", d.Date.String())
return c.JSON(d)
})
```
Run tests with the following `curl` commands:
```bash
# Body Binding
curl -X POST -F title=title -F body=body -F date=2021-10-20 http://localhost:3000/body
# Query Binding
curl -X GET "http://localhost:3000/query?title=title&body=body&date=2021-10-20"
```
## Validation
Validation is also possible with the binding methods. You can specify your validation rules using the `validate` struct tag.
Specify your struct validator in the [config](./fiber.md#structvalidator).
The validator must implement the `StructValidator` interface, which requires a `Validate` method that takes an `any` type and returns an error.
```go title="Interface"
type StructValidator interface {
Validate(out any) error
}
```
### Setup Your Validator in the Config
```go title="Example"
import "github.com/go-playground/validator/v10"
type structValidator struct {
validate *validator.Validate
}
// Validate method implementation
func (v *structValidator) Validate(out any) error {
return v.validate.Struct(out)
}
// Setup your validator in the Fiber config
app := fiber.New(fiber.Config{
StructValidator: &structValidator{validate: validator.New()},
})
```
Fiber only runs `StructValidator` for struct destinations (or pointers to structs).
Binding into maps and other non-struct types skips the validator step.
### Usage of Validation in Binding Methods
```go title="Example"
type Person struct {
Name string `json:"name" validate:"required"`
Age int `json:"age" validate:"gte=18,lte=60"`
}
app.Post("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().JSON(p); err != nil { // Receives validation errors
return err
}
return c.JSON(p)
})
```
## Default Fields
You can set default values for fields in the struct by using the `default` struct tag. Supported types:
- `bool`
- Float variants (`float32`, `float64`)
- Int variants (`int`, `int8`, `int16`, `int32`, `int64`)
- Uint variants (`uint`, `uint8`, `uint16`, `uint32`, `uint64`)
- `string`
- A slice of the above types. Use `|` to separate slice items.
- A pointer to one of the above types (**pointers to slices and slices of pointers are not supported**).
```go title="Example"
type Person struct {
Name string `query:"name,default:john"`
Pass string `query:"pass"`
Products []string `query:"products,default:shoe|hat"`
}
app.Get("/", func(c fiber.Ctx) error {
p := new(Person)
if err := c.Bind().Query(p); err != nil {
return err
}
log.Println(p.Name) // john
log.Println(p.Pass) // doe
log.Println(p.Products) // [shoe hat]
// ...
})
```
Run tests with the following `curl` command:
```bash
curl "http://localhost:3000/?pass=doe"
```
================================================
FILE: docs/api/constants.md
================================================
---
id: constants
title: 📋 Constants
description: Core HTTP constants used throughout Fiber.
sidebar_position: 10
---
### HTTP methods (mirrors `net/http`)
```go
const (
MethodGet = "GET" // RFC 7231, 4.3.1
MethodHead = "HEAD" // RFC 7231, 4.3.2
MethodPost = "POST" // RFC 7231, 4.3.3
MethodPut = "PUT" // RFC 7231, 4.3.4
MethodPatch = "PATCH" // RFC 5789
MethodDelete = "DELETE" // RFC 7231, 4.3.5
MethodConnect = "CONNECT" // RFC 7231, 4.3.6
MethodOptions = "OPTIONS" // RFC 7231, 4.3.7
MethodTrace = "TRACE" // RFC 7231, 4.3.8
methodUse = "USE"
)
```
### Common MIME types
```go
const (
MIMETextXML = "text/xml"
MIMETextHTML = "text/html"
MIMETextPlain = "text/plain"
MIMETextJavaScript = "text/javascript"
MIMETextCSS = "text/css"
MIMEApplicationXML = "application/xml"
MIMEApplicationJSON = "application/json"
MIMEApplicationCBOR = "application/cbor"
MIMEApplicationForm = "application/x-www-form-urlencoded"
MIMEOctetStream = "application/octet-stream"
MIMEMultipartForm = "multipart/form-data"
MIMETextXMLCharsetUTF8 = "text/xml; charset=utf-8"
MIMETextHTMLCharsetUTF8 = "text/html; charset=utf-8"
MIMETextPlainCharsetUTF8 = "text/plain; charset=utf-8"
MIMETextJavaScriptCharsetUTF8 = "text/javascript; charset=utf-8"
MIMETextCSSCharsetUTF8 = "text/css; charset=utf-8"
MIMEApplicationXMLCharsetUTF8 = "application/xml; charset=utf-8"
MIMEApplicationJSONCharsetUTF8 = "application/json; charset=utf-8"
)
```
### HTTP status codes (mirrors `net/http`)
```go
const (
StatusContinue = 100 // RFC 7231, 6.2.1
StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2
StatusProcessing = 102 // RFC 2518, 10.1
StatusEarlyHints = 103 // RFC 8297
StatusOK = 200 // RFC 7231, 6.3.1
StatusCreated = 201 // RFC 7231, 6.3.2
StatusAccepted = 202 // RFC 7231, 6.3.3
StatusNonAuthoritativeInformation = 203 // RFC 7231, 6.3.4
StatusNoContent = 204 // RFC 7231, 6.3.5
StatusResetContent = 205 // RFC 7231, 6.3.6
StatusPartialContent = 206 // RFC 7233, 4.1
StatusMultiStatus = 207 // RFC 4918, 11.1
StatusAlreadyReported = 208 // RFC 5842, 7.1
StatusIMUsed = 226 // RFC 3229, 10.4.1
StatusMultipleChoices = 300 // RFC 7231, 6.4.1
StatusMovedPermanently = 301 // RFC 7231, 6.4.2
StatusFound = 302 // RFC 7231, 6.4.3
StatusSeeOther = 303 // RFC 7231, 6.4.4
StatusNotModified = 304 // RFC 7232, 4.1
StatusUseProxy = 305 // RFC 7231, 6.4.5
StatusSwitchProxy = 306 // RFC 9110, 15.4.7 (Unused)
StatusTemporaryRedirect = 307 // RFC 7231, 6.4.7
StatusPermanentRedirect = 308 // RFC 7538, 3
StatusBadRequest = 400 // RFC 7231, 6.5.1
StatusUnauthorized = 401 // RFC 7235, 3.1
StatusPaymentRequired = 402 // RFC 7231, 6.5.2
StatusForbidden = 403 // RFC 7231, 6.5.3
StatusNotFound = 404 // RFC 7231, 6.5.4
StatusMethodNotAllowed = 405 // RFC 7231, 6.5.5
StatusNotAcceptable = 406 // RFC 7231, 6.5.6
StatusProxyAuthRequired = 407 // RFC 7235, 3.2
StatusRequestTimeout = 408 // RFC 7231, 6.5.7
StatusConflict = 409 // RFC 7231, 6.5.8
StatusGone = 410 // RFC 7231, 6.5.9
StatusLengthRequired = 411 // RFC 7231, 6.5.10
StatusPreconditionFailed = 412 // RFC 7232, 4.2
StatusRequestEntityTooLarge = 413 // RFC 7231, 6.5.11
StatusRequestURITooLong = 414 // RFC 7231, 6.5.12
StatusUnsupportedMediaType = 415 // RFC 7231, 6.5.13
StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4
StatusExpectationFailed = 417 // RFC 7231, 6.5.14
StatusTeapot = 418 // RFC 7168, 2.3.3
StatusMisdirectedRequest = 421 // RFC 7540, 9.1.2
StatusUnprocessableEntity = 422 // RFC 4918, 11.2
StatusLocked = 423 // RFC 4918, 11.3
StatusFailedDependency = 424 // RFC 4918, 11.4
StatusTooEarly = 425 // RFC 8470, 5.2.
StatusUpgradeRequired = 426 // RFC 7231, 6.5.15
StatusPreconditionRequired = 428 // RFC 6585, 3
StatusTooManyRequests = 429 // RFC 6585, 4
StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5
StatusUnavailableForLegalReasons = 451 // RFC 7725, 3
StatusInternalServerError = 500 // RFC 7231, 6.6.1
StatusNotImplemented = 501 // RFC 7231, 6.6.2
StatusBadGateway = 502 // RFC 7231, 6.6.3
StatusServiceUnavailable = 503 // RFC 7231, 6.6.4
StatusGatewayTimeout = 504 // RFC 7231, 6.6.5
StatusHTTPVersionNotSupported = 505 // RFC 7231, 6.6.6
StatusVariantAlsoNegotiates = 506 // RFC 2295, 8.1
StatusInsufficientStorage = 507 // RFC 4918, 11.5
StatusLoopDetected = 508 // RFC 5842, 7.2
StatusNotExtended = 510 // RFC 2774, 7
StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6
)
```
### Errors
```go
var (
ErrBadRequest = NewError(StatusBadRequest) // RFC 7231, 6.5.1
ErrUnauthorized = NewError(StatusUnauthorized) // RFC 7235, 3.1
ErrPaymentRequired = NewError(StatusPaymentRequired) // RFC 7231, 6.5.2
ErrForbidden = NewError(StatusForbidden) // RFC 7231, 6.5.3
ErrNotFound = NewError(StatusNotFound) // RFC 7231, 6.5.4
ErrMethodNotAllowed = NewError(StatusMethodNotAllowed) // RFC 7231, 6.5.5
ErrNotAcceptable = NewError(StatusNotAcceptable) // RFC 7231, 6.5.6
ErrProxyAuthRequired = NewError(StatusProxyAuthRequired) // RFC 7235, 3.2
ErrRequestTimeout = NewError(StatusRequestTimeout) // RFC 7231, 6.5.7
ErrConflict = NewError(StatusConflict) // RFC 7231, 6.5.8
ErrGone = NewError(StatusGone) // RFC 7231, 6.5.9
ErrLengthRequired = NewError(StatusLengthRequired) // RFC 7231, 6.5.10
ErrPreconditionFailed = NewError(StatusPreconditionFailed) // RFC 7232, 4.2
ErrRequestEntityTooLarge = NewError(StatusRequestEntityTooLarge) // RFC 7231, 6.5.11
ErrRequestURITooLong = NewError(StatusRequestURITooLong) // RFC 7231, 6.5.12
ErrUnsupportedMediaType = NewError(StatusUnsupportedMediaType) // RFC 7231, 6.5.13
ErrRequestedRangeNotSatisfiable = NewError(StatusRequestedRangeNotSatisfiable) // RFC 7233, 4.4
ErrExpectationFailed = NewError(StatusExpectationFailed) // RFC 7231, 6.5.14
ErrTeapot = NewError(StatusTeapot) // RFC 7168, 2.3.3
ErrMisdirectedRequest = NewError(StatusMisdirectedRequest) // RFC 7540, 9.1.2
ErrUnprocessableEntity = NewError(StatusUnprocessableEntity) // RFC 4918, 11.2
ErrLocked = NewError(StatusLocked) // RFC 4918, 11.3
ErrFailedDependency = NewError(StatusFailedDependency) // RFC 4918, 11.4
ErrTooEarly = NewError(StatusTooEarly) // RFC 8470, 5.2.
ErrUpgradeRequired = NewError(StatusUpgradeRequired) // RFC 7231, 6.5.15
ErrPreconditionRequired = NewError(StatusPreconditionRequired) // RFC 6585, 3
ErrTooManyRequests = NewError(StatusTooManyRequests) // RFC 6585, 4
ErrRequestHeaderFieldsTooLarge = NewError(StatusRequestHeaderFieldsTooLarge) // RFC 6585, 5
ErrUnavailableForLegalReasons = NewError(StatusUnavailableForLegalReasons) // RFC 7725, 3
ErrInternalServerError = NewError(StatusInternalServerError) // RFC 7231, 6.6.1
ErrNotImplemented = NewError(StatusNotImplemented) // RFC 7231, 6.6.2
ErrBadGateway = NewError(StatusBadGateway) // RFC 7231, 6.6.3
ErrServiceUnavailable = NewError(StatusServiceUnavailable) // RFC 7231, 6.6.4
ErrGatewayTimeout = NewError(StatusGatewayTimeout) // RFC 7231, 6.6.5
ErrHTTPVersionNotSupported = NewError(StatusHTTPVersionNotSupported) // RFC 7231, 6.6.6
ErrVariantAlsoNegotiates = NewError(StatusVariantAlsoNegotiates) // RFC 2295, 8.1
ErrInsufficientStorage = NewError(StatusInsufficientStorage) // RFC 4918, 11.5
ErrLoopDetected = NewError(StatusLoopDetected) // RFC 5842, 7.2
ErrNotExtended = NewError(StatusNotExtended) // RFC 2774, 7
ErrNetworkAuthenticationRequired = NewError(StatusNetworkAuthenticationRequired) // RFC 6585, 6
)
```
### HTTP headers
HTTP header name constants (originally copied from `net/http`):
```go
const (
HeaderAuthorization = "Authorization"
HeaderProxyAuthenticate = "Proxy-Authenticate"
HeaderProxyAuthorization = "Proxy-Authorization"
HeaderWWWAuthenticate = "WWW-Authenticate"
HeaderAge = "Age"
HeaderCacheControl = "Cache-Control"
HeaderClearSiteData = "Clear-Site-Data"
HeaderExpires = "Expires"
HeaderPragma = "Pragma"
HeaderWarning = "Warning"
HeaderAcceptCH = "Accept-CH"
HeaderAcceptCHLifetime = "Accept-CH-Lifetime"
HeaderContentDPR = "Content-DPR"
HeaderDPR = "DPR"
HeaderEarlyData = "Early-Data"
HeaderSaveData = "Save-Data"
HeaderViewportWidth = "Viewport-Width"
HeaderWidth = "Width"
HeaderETag = "ETag"
HeaderIfMatch = "If-Match"
HeaderIfModifiedSince = "If-Modified-Since"
HeaderIfNoneMatch = "If-None-Match"
HeaderIfUnmodifiedSince = "If-Unmodified-Since"
HeaderLastModified = "Last-Modified"
HeaderVary = "Vary"
HeaderConnection = "Connection"
HeaderKeepAlive = "Keep-Alive"
HeaderAccept = "Accept"
HeaderAcceptCharset = "Accept-Charset"
HeaderAcceptEncoding = "Accept-Encoding"
HeaderAcceptLanguage = "Accept-Language"
HeaderCookie = "Cookie"
HeaderExpect = "Expect"
HeaderMaxForwards = "Max-Forwards"
HeaderSetCookie = "Set-Cookie"
HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"
HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"
HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"
HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"
HeaderAccessControlMaxAge = "Access-Control-Max-Age"
HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"
HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
HeaderOrigin = "Origin"
HeaderTimingAllowOrigin = "Timing-Allow-Origin"
HeaderXPermittedCrossDomainPolicies = "X-Permitted-Cross-Domain-Policies"
HeaderDNT = "DNT"
HeaderTk = "Tk"
HeaderContentDisposition = "Content-Disposition"
HeaderContentEncoding = "Content-Encoding"
HeaderContentLanguage = "Content-Language"
HeaderContentLength = "Content-Length"
HeaderContentLocation = "Content-Location"
HeaderContentType = "Content-Type"
HeaderForwarded = "Forwarded"
HeaderVia = "Via"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXForwardedHost = "X-Forwarded-Host"
HeaderXForwardedProto = "X-Forwarded-Proto"
HeaderXForwardedProtocol = "X-Forwarded-Protocol"
HeaderXForwardedSsl = "X-Forwarded-Ssl"
HeaderXUrlScheme = "X-Url-Scheme"
HeaderLocation = "Location"
HeaderFrom = "From"
HeaderHost = "Host"
HeaderReferer = "Referer"
HeaderReferrerPolicy = "Referrer-Policy"
HeaderUserAgent = "User-Agent"
HeaderAllow = "Allow"
HeaderServer = "Server"
HeaderAcceptRanges = "Accept-Ranges"
HeaderContentRange = "Content-Range"
HeaderIfRange = "If-Range"
HeaderRange = "Range"
HeaderContentSecurityPolicy = "Content-Security-Policy"
HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only"
HeaderCrossOriginResourcePolicy = "Cross-Origin-Resource-Policy"
HeaderExpectCT = "Expect-CT"
HeaderFeaturePolicy = "Feature-Policy"
HeaderPublicKeyPins = "Public-Key-Pins"
HeaderPublicKeyPinsReportOnly = "Public-Key-Pins-Report-Only"
HeaderStrictTransportSecurity = "Strict-Transport-Security"
HeaderUpgradeInsecureRequests = "Upgrade-Insecure-Requests"
HeaderXContentTypeOptions = "X-Content-Type-Options"
HeaderXDownloadOptions = "X-Download-Options"
HeaderXFrameOptions = "X-Frame-Options"
HeaderXPoweredBy = "X-Powered-By"
HeaderXXSSProtection = "X-XSS-Protection"
HeaderLastEventID = "Last-Event-ID"
HeaderNEL = "NEL"
HeaderPingFrom = "Ping-From"
HeaderPingTo = "Ping-To"
HeaderReportTo = "Report-To"
HeaderTE = "TE"
HeaderTrailer = "Trailer"
HeaderTransferEncoding = "Transfer-Encoding"
HeaderSecWebSocketAccept = "Sec-WebSocket-Accept"
HeaderSecWebSocketExtensions = "Sec-WebSocket-Extensions"
HeaderSecWebSocketKey = "Sec-WebSocket-Key"
HeaderSecWebSocketProtocol = "Sec-WebSocket-Protocol"
HeaderSecWebSocketVersion = "Sec-WebSocket-Version"
HeaderAcceptPatch = "Accept-Patch"
HeaderAcceptPushPolicy = "Accept-Push-Policy"
HeaderAcceptSignature = "Accept-Signature"
HeaderAltSvc = "Alt-Svc"
HeaderDate = "Date"
HeaderIndex = "Index"
HeaderLargeAllocation = "Large-Allocation"
HeaderLink = "Link"
HeaderPushPolicy = "Push-Policy"
HeaderRetryAfter = "Retry-After"
HeaderServerTiming = "Server-Timing"
HeaderSignature = "Signature"
HeaderSignedHeaders = "Signed-Headers"
HeaderSourceMap = "SourceMap"
HeaderUpgrade = "Upgrade"
HeaderXDNSPrefetchControl = "X-DNS-Prefetch-Control"
HeaderXPingback = "X-Pingback"
HeaderXRequestID = "X-Request-ID"
HeaderXRequestedWith = "X-Requested-With"
HeaderXRobotsTag = "X-Robots-Tag"
HeaderXUACompatible = "X-UA-Compatible"
HeaderAccessControlAllowPrivateNetwork = "Access-Control-Allow-Private-Network"
HeaderAccessControlRequestPrivateNetwork = "Access-Control-Request-Private-Network"
)
```
================================================
FILE: docs/api/ctx.md
================================================
---
id: ctx
title: 🧠 Ctx
description: >-
The Ctx interface represents the Context which holds the HTTP request and
response. It has methods for the request query string, parameters, body, HTTP
headers, and so on.
sidebar_position: 3
---
### Abandon
Marks the context as abandoned. An abandoned context will not be returned to the pool when `ReleaseCtx` is called. This is used internally by the [timeout middleware](../middleware/timeout.md) to return immediately while the handler goroutine continues safely.
```go title="Signature"
func (c fiber.Ctx) Abandon()
func (c fiber.Ctx) IsAbandoned() bool
func (c fiber.Ctx) ForceRelease()
```
| Method | Description |
|:---------------|:----------------------------------------------------------------------------|
| `Abandon()` | Marks the context as abandoned. ReleaseCtx becomes a no-op for this context. |
| `IsAbandoned()`| Returns `true` if `Abandon()` was called on this context. |
| `ForceRelease()`| Releases an abandoned context back to the pool. Must only be called after the handler has completely finished. |
:::caution
These methods are primarily for internal use and advanced middleware development. Most applications should not need to call them directly.
:::
### App
Returns the [\*App](app.md) reference so you can easily access all application settings.
```go title="Signature"
func (c fiber.Ctx) App() *App
```
```go title="Example"
app.Get("/stack", func(c fiber.Ctx) error {
return c.JSON(c.App().Stack())
})
```
### Bind
Bind returns a helper for decoding the request body, query string, headers, cookies, and more.
For full details, see the [Bind](./bind.md) documentation.
```go title="Signature"
func (c fiber.Ctx) Bind() *Bind
```
```go title="Example"
app.Post("/", func(c fiber.Ctx) error {
user := new(User)
// Bind the request body to a struct:
return c.Bind().Body(user)
})
```
### Context
Returns a `context.Context` that was previously set with [`SetContext`](#setcontext).
If no context was set, it returns `context.Background()`. Unlike `fiber.Ctx` itself,
the returned context is safe to use after the handler completes.
```go title="Signature"
func (c fiber.Ctx) Context() context.Context
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
ctx := c.Context()
go doWork(ctx)
return nil
})
```
### context.Context
`Ctx` implements `context.Context`. However, due to [current limitations in how fasthttp](https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945) works, `Deadline()`, `Done()` and `Err()` are no-ops. The `fiber.Ctx` instance is reused after the handler returns and must not be used for asynchronous operations once the handler has completed. Call [`Context`](#context) within the handler to obtain a `context.Context` that can be used outside the handler.
```go title="Signature"
func (c fiber.Ctx) Deadline() (deadline time.Time, ok bool)
func (c fiber.Ctx) Done() <-chan struct{}
func (c fiber.Ctx) Err() error
func (c fiber.Ctx) Value(key any) any
```
```go title="Example"
func doSomething(ctx context.Context) {
// ...
}
app.Get("/", func(c fiber.Ctx) error {
doSomething(c)
return nil
})
```
#### Value
Value can be used to retrieve [**`Locals`**](#locals).
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Locals(userKey, "admin")
user := c.Value(userKey) // returns "admin"
return c.JSON(fiber.Map{"user": user})
})
```
### Drop
Terminates the client connection silently without sending any HTTP headers or response body.
This can be used for scenarios where you want to block certain requests without notifying the client, such as mitigating
DDoS attacks or protecting sensitive endpoints from unauthorized access.
```go title="Signature"
func (c fiber.Ctx) Drop() error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
if c.IP() == "192.168.1.1" {
return c.Drop()
}
return c.SendString("Hello World!")
})
```
### FullPath
Returns the full path of the matched route. This includes any prefixes that were added by [groups](../guide/routing.md#grouping) or mounts.
```go title="Signature"
func (c fiber.Ctx) FullPath() string
```
```go title="Example"
api := app.Group("/api")
api.Get("/users/:id", func(c fiber.Ctx) error {
return c.JSON(fiber.Map{
"route": c.FullPath(), // "/api/users/:id"
})
})
app.Use(func(c fiber.Ctx) error {
beforeNext := c.FullPath() // "/"
if err := c.Next(); err != nil {
return err
}
afterNext := c.FullPath() // "/api/users/:id"
// ... react to the downstream handler's route path
return nil
})
```
### GetReqHeaders
Returns the HTTP request headers as a map. Because a header can appear multiple times in a request, each key maps to a slice with all values for that header.
```go title="Signature"
func (c fiber.Ctx) GetReqHeaders() map[string][]string
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### GetRespHeader
Returns the HTTP response header specified by the field.
:::tip
The match is **case-insensitive**.
:::
```go title="Signature"
func (c fiber.Ctx) GetRespHeader(key string, defaultValue ...string) string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.GetRespHeader("X-Request-Id") // "8d7ad5e3-aaf3-450b-a241-2beb887efd54"
c.GetRespHeader("Content-Type") // "text/plain"
c.GetRespHeader("something", "john") // "john"
// ..
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### GetRespHeaders
Returns the HTTP response headers as a map. Since a header can be set multiple times in a single response, the values of the map are slices of strings containing all the different values of the header.
```go title="Signature"
func (c fiber.Ctx) GetRespHeaders() map[string][]string
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### GetRouteURL
Generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831"
```go title="Signature"
func (c fiber.Ctx) GetRouteURL(routeName string, params Map) (string, error)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Home page")
}).Name("home")
app.Get("/user/:id", func(c fiber.Ctx) error {
return c.SendString(c.Params("id"))
}).Name("user.show")
app.Get("/test", func(c fiber.Ctx) error {
location, _ := c.GetRouteURL("user.show", fiber.Map{"id": 1})
return c.SendString(location)
})
// /test returns "/user/1"
```
### HasBody
Returns `true` if the incoming request contains a body or a `Content-Length` header greater than zero.
```go title="Signature"
func (c fiber.Ctx) HasBody() bool
```
```go title="Example"
app.Post("/", func(c fiber.Ctx) error {
if !c.HasBody() {
return c.SendStatus(fiber.StatusBadRequest)
}
return c.SendString("OK")
})
```
### IsMiddleware
Returns `true` if the current request handler was registered as middleware.
```go title="Signature"
func (c fiber.Ctx) IsMiddleware() bool
```
```go title="Example"
app.Get("/route", func(c fiber.Ctx) error {
fmt.Println(c.IsMiddleware()) // true
return c.Next()
}, func(c fiber.Ctx) error {
fmt.Println(c.IsMiddleware()) // false
return c.SendStatus(fiber.StatusOK)
})
```
### IsPreflight
Returns `true` if the request is a CORS preflight (`OPTIONS` + `Access-Control-Request-Method` + `Origin`).
```go title="Signature"
func (c fiber.Ctx) IsPreflight() bool
```
```go title="Example"
app.Use(func(c fiber.Ctx) error {
if c.IsPreflight() {
return c.SendStatus(fiber.StatusNoContent)
}
return c.Next()
})
```
### IsWebSocket
Returns `true` if the request includes a WebSocket upgrade handshake.
```go title="Signature"
func (c fiber.Ctx) IsWebSocket() bool
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
if c.IsWebSocket() {
// handle websocket
}
return c.Next()
})
```
### Locals
Stores variables scoped to the request, making them available only to matching routes. The variables are removed after the request completes. If a stored value implements `io.Closer`, Fiber calls its `Close` method before removal.
:::tip
This is useful if you want to pass some **specific** data to the next middleware. Remember to perform type assertions when retrieving the data to ensure it is of the expected type. You can also use a non-exported type as a key to avoid collisions.
:::
```go title="Signature"
func (c fiber.Ctx) Locals(key any, value ...any) any
```
```go title="Example"
// keyType is an unexported type for keys defined in this package.
// This prevents collisions with keys defined in other packages.
type keyType int
// userKey is the key under which the current user is stored in Locals.
// Because its type is unexported, code in other packages cannot
// construct an equal key and accidentally overwrite the value.
var userKey keyType
app.Use(func(c fiber.Ctx) error {
c.Locals(userKey, "admin") // Stores the string "admin" under a non-exported type key
return c.Next()
})
app.Get("/admin", func(c fiber.Ctx) error {
user, ok := c.Locals(userKey).(string) // Retrieves the data stored under the key and performs a type assertion
if ok && user == "admin" {
return c.Status(fiber.StatusOK).SendString("Welcome, admin!")
}
return c.SendStatus(fiber.StatusForbidden)
})
```
An alternative version of the `Locals` method that takes advantage of Go's generics feature is also available. This version allows for the manipulation and retrieval of local values within a request's context with a more specific data type.
```go title="Signature"
func Locals[V any](c fiber.Ctx, key any, value ...V) V
```
```go title="Example"
app.Use(func(c fiber.Ctx) error {
fiber.Locals[string](c, "john", "doe")
fiber.Locals[int](c, "age", 18)
fiber.Locals[bool](c, "isHuman", true)
return c.Next()
})
app.Get("/test", func(c fiber.Ctx) error {
fiber.Locals[string](c, "john") // "doe"
fiber.Locals[int](c, "age") // 18
fiber.Locals[bool](c, "isHuman") // true
return nil
})
```
Use either the standard or the generic form of `Locals`, whichever fits your code best, to manage route-specific data within your application.
### Matched
Returns `true` if the current request path was matched by the router.
```go title="Signature"
func (c fiber.Ctx) Matched() bool
```
```go title="Example"
app.Use(func(c fiber.Ctx) error {
if c.Matched() {
return c.Next()
}
return c.Status(fiber.StatusNotFound).SendString("Not Found")
})
```
### Next
When **Next** is called, it executes the next method in the stack that matches the current route. If a handler returns an error, the chain ends and the [error handler](../guide/error-handling) is called.
```go title="Signature"
func (c fiber.Ctx) Next() error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
fmt.Println("1st route!")
return c.Next()
})
app.Get("*", func(c fiber.Ctx) error {
fmt.Println("2nd route!")
return c.Next()
})
app.Get("/", func(c fiber.Ctx) error {
fmt.Println("3rd route!")
return c.SendString("Hello, World!")
})
```
### OverrideParam
Overwrites the value of an existing route parameter.
:::note
If the parameter does not exist, this method does nothing.
:::
```go title="Signature"
func (c fiber.Ctx) OverrideParam(name, value string)
```
```go title="Example"
// GET http://example.com/user/john
app.Get("/user/:name", func(c fiber.Ctx) error {
// mutate parameter
c.OverrideParam("name", "new value")
return c.SendString(c.Params("name")) // sends "new value"
})
// GET http://example.com/shop/tech/1
app.Get("/shop/*", func(c fiber.Ctx) error {
// mutate parameter
c.OverrideParam("*", "new tech") // replaces "tech/1" with "new tech"
return c.SendString(c.Params("*")) // sends "new tech"
})
```
Unnamed route parameters can be accessed by their character (`*` or `+`) followed by their position index (e.g., `*1` for the first wildcard, `*2` for the second).
```go title="Example"
// GET /v1/brand/4/shop/blue/xs
app.Get("/v1/*/shop/*", func(c fiber.Ctx) error {
// mutate parameter
c.OverrideParam("*1", "updated brand")
c.OverrideParam("*2", "updated data")
param1 := c.Params("*1") // "updated brand"
param2 := c.Params("*2") // "updated data"
// ...
})
```
### Redirect
Returns the Redirect reference.
For detailed information, check the [Redirect](./redirect.md) documentation.
```go title="Signature"
func (c fiber.Ctx) Redirect() *Redirect
```
```go title="Example"
app.Get("/coffee", func(c fiber.Ctx) error {
return c.Redirect().To("/teapot")
})
app.Get("/teapot", func(c fiber.Ctx) error {
return c.Status(fiber.StatusTeapot).SendString("🍵 short and stout 🍵")
})
```
### Request
Returns the [*fasthttp.Request](https://pkg.go.dev/github.com/valyala/fasthttp#Request) pointer.
```go title="Signature"
func (c fiber.Ctx) Request() *fasthttp.Request
```
:::info
Returns `nil` if the context has been released (e.g., after the handler completes and the context is returned to the pool).
:::
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Request().Header.Method()
// => []byte("GET")
return nil
})
```
### RequestCtx
Returns the underlying [\*fasthttp.RequestCtx](https://pkg.go.dev/github.com/valyala/fasthttp#RequestCtx), which implements the `context.Context` interface used to carry deadlines, cancellation signals, and request-scoped values across API boundaries.
```go title="Signature"
func (c fiber.Ctx) RequestCtx() *fasthttp.RequestCtx
```
:::info
Please read the [Fasthttp Documentation](https://pkg.go.dev/github.com/valyala/fasthttp?tab=doc) for more information.
:::
### Reset
Resets the context fields using the given request context when using server handlers.
```go title="Signature"
func (c fiber.Ctx) Reset(fctx *fasthttp.RequestCtx)
```
It is used outside of the Fiber Handlers to reset the context for the next request.
### Response
Returns the [\*fasthttp.Response](https://pkg.go.dev/github.com/valyala/fasthttp#Response) pointer.
```go title="Signature"
func (c fiber.Ctx) Response() *fasthttp.Response
```
:::info
Returns `nil` if the context has been released (e.g., after the handler completes and the context is returned to the pool).
:::
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Response().BodyWriter().Write([]byte("Hello, World!"))
// => "Hello, World!"
return nil
})
```
### RestartRouting
Instead of executing the next method when calling [Next](ctx.md#next), **RestartRouting** restarts execution from the first method that matches the current route. This may be helpful after overriding the path, e.g., for an internal redirect. Note that handlers might be executed again, which could result in an infinite loop.
```go title="Signature"
func (c fiber.Ctx) RestartRouting() error
```
```go title="Example"
app.Get("/new", func(c fiber.Ctx) error {
return c.SendString("From /new")
})
app.Get("/old", func(c fiber.Ctx) error {
c.Path("/new")
return c.RestartRouting()
})
```
### Route
Returns the matched [Route](https://pkg.go.dev/github.com/gofiber/fiber/v3#Route) struct.
```go title="Signature"
func (c fiber.Ctx) Route() *Route
```
```go title="Example"
// GET http://localhost:8080/hello/fenny
app.Get("/hello/:name", func(c fiber.Ctx) error {
r := c.Route()
fmt.Println(r.Method, r.Path, r.Params, r.Handlers)
// GET /hello/:name [name] [0x...]
// ...
})
```
:::caution
Do not rely on `c.Route()` in middlewares **before** calling `c.Next()` - `c.Route()` returns the **last executed route**.
:::
```go title="Example"
func MyMiddleware() fiber.Handler {
return func(c fiber.Ctx) error {
beforeNext := c.Route().Path // Will be '/'
err := c.Next()
afterNext := c.Route().Path // Will be '/hello/:name'
return err
}
}
```
### SetContext
Sets the base `context.Context` used by [`Context`](#context). Use this to
propagate deadlines, cancellation signals, or values to asynchronous operations.
```go title="Signature"
func (c fiber.Ctx) SetContext(ctx context.Context)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.SetContext(context.WithValue(context.Background(), "user", "alice"))
ctx := c.Context()
go doWork(ctx)
return nil
})
```
### String
Returns a unique string representation of the context.
```go title="Signature"
func (c fiber.Ctx) String() string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.String() // => "#0000000100000001 - 127.0.0.1:3000 <-> 127.0.0.1:61516 - GET http://localhost:3000/"
// ...
})
```
### ViewBind
Adds variables to the default view variable map binding to the template engine.
Variables are read by the `Render` method and may be overwritten.
```go title="Signature"
func (c fiber.Ctx) ViewBind(vars Map) error
```
```go title="Example"
app.Use(func(c fiber.Ctx) error {
c.ViewBind(fiber.Map{
"Title": "Hello, World!",
})
return c.Next()
})
app.Get("/", func(c fiber.Ctx) error {
return c.Render("xxx.tmpl", fiber.Map{}) // Render will use the Title variable
})
```
## Request
Methods which operate on the incoming request.
:::tip
Use `c.Req()` to limit gopls suggestions to only these methods!
:::
### AcceptEncoding
Returns the `Accept-Encoding` request header.
```go title="Signature"
func (c fiber.Ctx) AcceptEncoding() string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.AcceptEncoding() // "gzip, br"
return nil
})
```
### AcceptLanguage
Returns the `Accept-Language` request header.
```go title="Signature"
func (c fiber.Ctx) AcceptLanguage() string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.AcceptLanguage() // "en-US,en;q=0.9"
return nil
})
```
### Accepts
Checks if the specified **extensions** or **content** **types** are acceptable.
:::info
Based on the request’s [Accept](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept) HTTP header.
:::
```go title="Signature"
func (c fiber.Ctx) Accepts(offers ...string) string
func (c fiber.Ctx) AcceptsCharsets(offers ...string) string
func (c fiber.Ctx) AcceptsEncodings(offers ...string) string
func (c fiber.Ctx) AcceptsLanguages(offers ...string) string
func (c fiber.Ctx) AcceptsLanguagesExtended(offers ...string) string
```
```go title="Example"
// Accept: text/html, application/json;q=0.8, text/plain;q=0.5
app.Get("/", func(c fiber.Ctx) error {
c.Accepts("html") // "html"
c.Accepts("text/html") // "text/html"
c.Accepts("json", "text") // "json"
c.Accepts("application/json") // "application/json"
c.Accepts("text/plain", "application/json") // "application/json", due to quality
c.Accepts("image/png") // ""
c.Accepts("png") // ""
// ...
})
```
```go title="Example 2"
// Accept: text/html, text/*, application/json, */*; q=0
app.Get("/", func(c fiber.Ctx) error {
c.Accepts("text/plain", "application/json") // "application/json", due to specificity
c.Accepts("application/json", "text/html") // "text/html", due to first match
c.Accepts("image/png") // "", due to */* with q=0 is Not Acceptable
// ...
})
```
Media-Type parameters are supported.
```go title="Example 3"
// Accept: text/plain, application/json; version=1; foo=bar
app.Get("/", func(c fiber.Ctx) error {
// Extra parameters in the accept are ignored
c.Accepts("text/plain;format=flowed") // "text/plain;format=flowed"
// An offer must contain all parameters present in the Accept type
c.Accepts("application/json") // ""
// Parameter order and capitalization do not matter. Quotes on values are stripped.
c.Accepts(`application/json;foo="bar";VERSION=1`) // "application/json;foo="bar";VERSION=1"
})
```
```go title="Example 4"
// Accept: text/plain;format=flowed;q=0.9, text/plain
// i.e., "I prefer text/plain;format=flowed less than other forms of text/plain"
app.Get("/", func(c fiber.Ctx) error {
// Beware: the order in which offers are listed matters.
// Although the client specified they prefer not to receive format=flowed,
// the text/plain Accept matches with "text/plain;format=flowed" first, so it is returned.
c.Accepts("text/plain;format=flowed", "text/plain") // "text/plain;format=flowed"
// Here, things behave as expected:
c.Accepts("text/plain", "text/plain;format=flowed") // "text/plain"
})
```
Fiber provides similar functions for the other accept headers.
For `Accept-Language`, Fiber uses the [Basic Filtering](https://www.rfc-editor.org/rfc/rfc4647#section-3.3.1) algorithm. A language range matches an offer only if it exactly equals the tag or is a prefix followed by a hyphen. For example, the range `en` matches `en-US`, but `en-US` does not match `en`.
`AcceptsLanguagesExtended` applies [Extended Filtering](https://www.rfc-editor.org/rfc/rfc4647#section-3.3.2) where `*` may match zero or more subtags and wildcard matches can slide across subtags unless blocked by a singleton like `x`.
```go
// Accept-Charset: utf-8, iso-8859-1;q=0.2
// Accept-Encoding: gzip, compress;q=0.2
// Accept-Language: en;q=0.8, nl, ru
app.Get("/", func(c fiber.Ctx) error {
c.AcceptsCharsets("utf-16", "iso-8859-1")
// "iso-8859-1"
c.AcceptsEncodings("compress", "br")
// "compress"
c.AcceptsLanguages("pt", "nl", "ru")
// "nl"
c.AcceptsLanguagesExtended("en-US", "fr-CA")
// depends on extended ranges in the request header
// ...
})
```
### AcceptsEventStream
Returns `true` when the `Accept` header allows `text/event-stream`.
```go title="Signature"
func (c fiber.Ctx) AcceptsEventStream() bool
```
```go title="Example"
// Accept: text/html, application/json;q=0.9
app.Get("/", func(c fiber.Ctx) error {
c.AcceptsEventStream() // false
return nil
})
```
### AcceptsHTML
Returns `true` when the `Accept` header allows HTML.
```go title="Signature"
func (c fiber.Ctx) AcceptsHTML() bool
```
```go title="Example"
// Accept: text/html, application/json;q=0.9
app.Get("/", func(c fiber.Ctx) error {
c.AcceptsHTML() // true
return nil
})
```
### AcceptsJSON
Returns `true` when the `Accept` header allows JSON.
```go title="Signature"
func (c fiber.Ctx) AcceptsJSON() bool
```
```go title="Example"
// Accept: text/html, application/json;q=0.9
app.Get("/", func(c fiber.Ctx) error {
c.AcceptsJSON() // true
return nil
})
```
### AcceptsXML
Returns `true` when the `Accept` header allows XML.
```go title="Signature"
func (c fiber.Ctx) AcceptsXML() bool
```
```go title="Example"
// Accept: text/html, application/json;q=0.9
app.Get("/", func(c fiber.Ctx) error {
c.AcceptsXML() // false
return nil
})
```
### BaseURL
Returns the base URL (**protocol** + **host**) as a `string`.
```go title="Signature"
func (c fiber.Ctx) BaseURL() string
```
```go title="Example"
// GET https://example.com/page#chapter-1
app.Get("/", func(c fiber.Ctx) error {
c.BaseURL() // "https://example.com"
// ...
})
```
### Body
Based on the `Content-Encoding` header, this method tries to decompress the **body** bytes. If no `Content-Encoding` header is sent (or it is set to `identity`), it behaves like [BodyRaw](#bodyraw). If an unknown or unsupported encoding is encountered, the response status will be `415 Unsupported Media Type` or `501 Not Implemented`.
```go title="Signature"
func (c fiber.Ctx) Body() []byte
```
```go title="Example"
// echo 'user=john' | gzip | curl -v -i --data-binary @- -H "Content-Encoding: gzip" http://localhost:8080
app.Post("/", func(c fiber.Ctx) error {
// Decompress body from POST request based on the Content-Encoding and return the raw content:
return c.Send(c.Body()) // []byte("user=john")
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### BodyRaw
Returns the raw request **body**.
```go title="Signature"
func (c fiber.Ctx) BodyRaw() []byte
```
```go title="Example"
// curl -X POST http://localhost:8080 -d user=john
app.Post("/", func(c fiber.Ctx) error {
// Get raw body from POST request:
return c.Send(c.BodyRaw()) // []byte("user=john")
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### Charset
Returns the `charset` parameter from the `Content-Type` header.
```go title="Signature"
func (c fiber.Ctx) Charset() string
```
```go title="Example"
// Content-Type: application/json; charset=utf-8
app.Post("/", func(c fiber.Ctx) error {
c.Charset() // "utf-8"
return nil
})
```
### ClientHelloInfo
`ClientHelloInfo` contains information from a ClientHello message to guide application logic in the `GetCertificate` and `GetConfigForClient` callbacks.
Refer to the [ClientHelloInfo](https://golang.org/pkg/crypto/tls/#ClientHelloInfo) struct documentation for details on the returned struct.
```go title="Signature"
func (c fiber.Ctx) ClientHelloInfo() *tls.ClientHelloInfo
```
```go title="Example"
// GET http://example.com/hello
app.Get("/hello", func(c fiber.Ctx) error {
chi := c.ClientHelloInfo()
// ...
})
```
### Cookies
Gets a cookie value by key. You can pass an optional default value that will be returned if the cookie key does not exist.
```go title="Signature"
func (c fiber.Ctx) Cookies(key string, defaultValue ...string) string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
// Get cookie by key:
c.Cookies("name") // "john"
c.Cookies("empty", "doe") // "doe"
// ...
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Use [`App.GetString`](./app.md#getstring) or [`App.GetBytes`](./app.md#getbytes) when immutability is enabled, or manually copy values (for example with [`utils.CopyString`](https://github.com/gofiber/utils) / `utils.CopyBytes`) when it's disabled. [Read more...](../#zero-allocation)
:::
### FormFile
MultipartForm files can be retrieved by name, the **first** file from the given key is returned.
```go title="Signature"
func (c fiber.Ctx) FormFile(key string) (*multipart.FileHeader, error)
```
```go title="Example"
app.Post("/", func(c fiber.Ctx) error {
// Get first file from form field "document":
file, err := c.FormFile("document")
if err != nil {
return err
}
// Save file to root directory:
return c.SaveFile(file, fmt.Sprintf("./%s", file.Filename))
})
```
### FormValue
Form values can be retrieved by name, the **first** value for the given key is returned.
```go title="Signature"
func (c fiber.Ctx) FormValue(key string, defaultValue ...string) string
```
```go title="Example"
app.Post("/", func(c fiber.Ctx) error {
// Get first value from form field "name":
c.FormValue("name")
// => "john" or "" if not exist
// ..
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### Fresh
When the response is still **fresh** in the client's cache, **true** is returned; otherwise, **false** is returned to indicate that the client cache is now stale and the full response should be sent.
When a client sends the `Cache-Control: no-cache` request header to indicate an end-to-end reload request, `Fresh` returns false to make handling these requests transparent.
Read more on [https://expressjs.com/en/4x/api.html\#req.fresh](https://expressjs.com/en/4x/api.html#req.fresh)
```go title="Signature"
func (c fiber.Ctx) Fresh() bool
```
### FullURL
Returns the full request URL (protocol + host + original URL).
```go title="Signature"
func (c fiber.Ctx) FullURL() string
```
```go title="Example"
// GET http://example.com/search?q=fiber
app.Get("/", func(c fiber.Ctx) error {
c.FullURL() // "http://example.com/search?q=fiber"
return nil
})
```
### Get
Returns the HTTP request header specified by the field.
:::tip
The match is **case-insensitive**.
:::
```go title="Signature"
func (c fiber.Ctx) Get(key string, defaultValue ...string) string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Get("Content-Type") // "text/plain"
c.Get("CoNtEnT-TypE") // "text/plain"
c.Get("something", "john") // "john"
// ..
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### HasHeader
Reports whether the request includes a header with the given key.
```go title="Signature"
func (c fiber.Ctx) HasHeader(key string) bool
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.HasHeader("X-Trace-Id")
return nil
})
```
### Host
Returns the host derived from the [Host](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) HTTP header.
In a network context, [`Host`](#host) refers to the combination of a hostname and potentially a port number used for connecting, while [`Hostname`](#hostname) refers specifically to the name assigned to a device on a network, excluding any port information.
```go title="Signature"
func (c fiber.Ctx) Host() string
```
```go title="Example"
// GET http://google.com:8080/search
app.Get("/", func(c fiber.Ctx) error {
c.Host() // "google.com:8080"
c.Hostname() // "google.com"
// ...
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### Hostname
Returns the hostname derived from the [Host](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) HTTP header.
```go title="Signature"
func (c fiber.Ctx) Hostname() string
```
```go title="Example"
// GET http://google.com/search
app.Get("/", func(c fiber.Ctx) error {
c.Hostname() // "google.com"
// ...
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### IP
Returns the remote IP address of the request.
```go title="Signature"
func (c fiber.Ctx) IP() string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.IP() // "127.0.0.1"
// ...
})
```
:::info
By default, `c.IP()` returns the remote IP address from the TCP connection. When your Fiber app is behind a reverse proxy (like Nginx, Traefik, or a load balancer), you need to configure **both** [`TrustProxy`](fiber.md#trustproxy) and [`ProxyHeader`](fiber.md#proxyheader) to read the client IP from proxy headers like `X-Forwarded-For`.
**Important:** You must enable `TrustProxy` and configure trusted proxy IPs to prevent header spoofing. Simply setting `ProxyHeader` alone will not work.
**Note:** When using a proxy header such as `X-Forwarded-For`, `c.IP()` returns the raw header value unless [`EnableIPValidation`](fiber.md#enableipvalidation) is enabled. For `X-Forwarded-For`, this raw value may be a comma-separated list of IPs; enable `EnableIPValidation` if you need `c.IP()` to return a single, validated client IP.
:::
#### Configuration for apps behind a reverse proxy
```go title="Example - Basic Configuration"
app := fiber.New(fiber.Config{
// Enable proxy support
TrustProxy: true,
// Specify which header contains the real client IP
ProxyHeader: fiber.HeaderXForwardedFor,
// Configure which proxy IPs to trust
TrustProxyConfig: fiber.TrustProxyConfig{
// Trust private IP ranges (for internal load balancers)
Private: true,
// Or specify exact proxy IPs/ranges
// Proxies: []string{"10.10.0.58", "192.168.0.0/24"},
},
})
```
```go title="Example - Specific Proxy IPs"
app := fiber.New(fiber.Config{
TrustProxy: true,
ProxyHeader: fiber.HeaderXForwardedFor,
TrustProxyConfig: fiber.TrustProxyConfig{
// Trust only specific proxy IP addresses
Proxies: []string{"10.10.0.58", "192.168.1.0/24"},
},
})
```
See [`TrustProxy`](fiber.md#trustproxy) and [`TrustProxyConfig`](fiber.md#trustproxyconfig) for more details on security considerations and configuration options.
### IPs
Returns an array of IP addresses specified in the [X-Forwarded-For](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) request header.
```go title="Signature"
func (c fiber.Ctx) IPs() []string
```
```go title="Example"
// X-Forwarded-For: proxy1, 127.0.0.1, proxy3
app.Get("/", func(c fiber.Ctx) error {
c.IPs() // ["proxy1", "127.0.0.1", "proxy3"]
// ...
})
```
:::caution
Improper use of the X-Forwarded-For header can be a security risk. For details, see the [Security and privacy concerns](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For#security_and_privacy_concerns) section.
:::
### Is
Returns `true` if the incoming request’s [Content-Type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) HTTP header field matches the [MIME type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types) specified by the extension parameter.
:::info
If the request has **no** body, it returns **false**.
:::
```go title="Signature"
func (c fiber.Ctx) Is(extension string) bool
```
```go title="Example"
// Content-Type: text/html; charset=utf-8
app.Get("/", func(c fiber.Ctx) error {
c.Is("html") // true
c.Is(".html") // true
c.Is("json") // false
// ...
})
```
### IsForm
Reports whether the `Content-Type` header is form-encoded.
```go title="Signature"
func (c fiber.Ctx) IsForm() bool
```
```go title="Example"
// Content-Type: application/x-www-form-urlencoded
app.Post("/", func(c fiber.Ctx) error {
c.IsForm() // true
return nil
})
```
### IsFromLocal
Returns `true` if the request came from localhost.
```go title="Signature"
func (c fiber.Ctx) IsFromLocal() bool
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
// If request came from localhost, return true; else return false
c.IsFromLocal()
// ...
})
```
### IsJSON
Reports whether the `Content-Type` header is JSON.
```go title="Signature"
func (c fiber.Ctx) IsJSON() bool
```
```go title="Example"
// Content-Type: application/json; charset=utf-8
app.Post("/", func(c fiber.Ctx) error {
c.IsJSON() // true
return nil
})
```
### IsMultipart
Reports whether the `Content-Type` header is multipart form data.
```go title="Signature"
func (c fiber.Ctx) IsMultipart() bool
```
```go title="Example"
// Content-Type: multipart/form-data; boundary=abc123
app.Post("/", func(c fiber.Ctx) error {
c.IsMultipart() // true
return nil
})
```
### IsProxyTrusted
Checks the trustworthiness of the remote IP.
If [`TrustProxy`](fiber.md#trustproxy) is `false`, it returns `true`.
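Otherwise, `IsProxyTrusted` checks the remote IP against the trusted proxies configured in [`TrustProxyConfig`](fiber.md#trustproxyconfig), which can list individual IP addresses as well as IP ranges.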
```go title="Signature"
func (c fiber.Ctx) IsProxyTrusted() bool
```
```go title="Example"
app := fiber.New(fiber.Config{
// TrustProxy enables the trusted proxy check
TrustProxy: true,
// TrustProxyConfig allows for configuring trusted proxies.
// Proxies is a list of trusted proxy IP ranges/addresses
TrustProxyConfig: fiber.TrustProxyConfig{
Proxies: []string{"0.8.0.0", "1.1.1.1/30"}, // IP address or IP address range
Loopback: true, // Trust loopback addresses (127.0.0.0/8, ::1/128)
UnixSocket: true, // Trust Unix domain socket connections
},
})
app.Get("/", func(c fiber.Ctx) error {
// If request came from trusted proxy, return true; else return false
c.IsProxyTrusted()
// ...
})
```
### MediaType
Returns the MIME type from the `Content-Type` header without parameters.
```go title="Signature"
func (c fiber.Ctx) MediaType() string
```
```go title="Example"
// Content-Type: application/json; charset=utf-8
app.Post("/", func(c fiber.Ctx) error {
c.MediaType() // "application/json"
return nil
})
```
### Method
Returns a string corresponding to the HTTP method of the request: `GET`, `POST`, `PUT`, and so on.
Optionally, you can override the method by passing a string.
```go title="Signature"
func (c fiber.Ctx) Method(override ...string) string
```
```go title="Example"
app.Post("/override", func(c fiber.Ctx) error {
c.Method() // "POST"
c.Method("GET")
c.Method() // "GET"
// ...
})
```
### MultipartForm
To access multipart form entries, you can parse the binary with `MultipartForm()`. This returns a `*multipart.Form`, allowing you to access form values and files.
```go title="Signature"
func (c fiber.Ctx) MultipartForm() (*multipart.Form, error)
```
```go title="Example"
app.Post("/", func(c fiber.Ctx) error {
// Parse the multipart form:
if form, err := c.MultipartForm(); err == nil {
// => *multipart.Form
if token := form.Value["token"]; len(token) > 0 {
// Get key value:
fmt.Println(token[0])
}
// Get all files from "documents" key:
files := form.File["documents"]
// => []*multipart.FileHeader
// Loop through files:
for _, file := range files {
fmt.Println(file.Filename, file.Size, file.Header["Content-Type"][0])
// => "tutorial.pdf" 360641 "application/pdf"
// Save the files to disk:
if err := c.SaveFile(file, fmt.Sprintf("./%s", file.Filename)); err != nil {
return err
}
}
}
return nil
})
```
### OriginalURL
Returns the original request URL.
```go title="Signature"
func (c fiber.Ctx) OriginalURL() string
```
```go title="Example"
// GET http://example.com/search?q=something
app.Get("/", func(c fiber.Ctx) error {
c.OriginalURL() // "/search?q=something"
// ...
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
### Params
This method can be used to get the route parameters. You can pass an optional default value that will be returned if the param key does not exist.
:::info
Defaults to an empty string \(`""`\) if the param **doesn't** exist.
:::
```go title="Signature"
func (c fiber.Ctx) Params(key string, defaultValue ...string) string
```
```go title="Example"
// GET http://example.com/user/fenny
app.Get("/user/:name", func(c fiber.Ctx) error {
c.Params("name") // "fenny"
// ...
})
// GET http://example.com/user/fenny/123
app.Get("/user/*", func(c fiber.Ctx) error {
c.Params("*") // "fenny/123"
c.Params("*1") // "fenny/123"
// ...
})
```
Unnamed route parameters \(\*, +\) can be fetched by the **character** and the **counter** in the route.
```go title="Example"
// ROUTE: /v1/*/shop/*
// GET: /v1/brand/4/shop/blue/xs
c.Params("*1") // "brand/4"
c.Params("*2") // "blue/xs"
```
For reasons of **backward compatibility**, the first parameter segment for the parameter character can also be accessed without the counter.
```go title="Example"
app.Get("/v1/*/shop/*", func(c fiber.Ctx) error {
c.Params("*") // outputs the value of the first wildcard segment
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
In certain scenarios, it can be useful to have an alternative approach to handle different types of parameters, not
just strings. This can be achieved using a generic `Params` function known as `Params[V GenericType](c fiber.Ctx, key string, defaultValue ...V) V`.
This function is capable of parsing a route parameter and returning a value of a type that is assumed and specified by `V GenericType`.
```go title="Signature"
func Params[V GenericType](c fiber.Ctx, key string, defaultValue ...V) V
```
```go title="Example"
// GET http://example.com/user/114
app.Get("/user/:id", func(c fiber.Ctx) error{
fiber.Params[string](c, "id") // returns "114" as string.
fiber.Params[int](c, "id") // returns 114 as integer
fiber.Params[string](c, "number") // returns "" (default string type)
fiber.Params[int](c, "number") // returns 0 (default integer value type)
})
```
The generic `Params` function supports returning the following data types based on `V GenericType`:
- Integer: `int`, `int8`, `int16`, `int32`, `int64`
- Unsigned integer: `uint`, `uint8`, `uint16`, `uint32`, `uint64`
- Floating-point numbers: `float32`, `float64`
- Boolean: `bool`
- String: `string`
- Byte slice: `[]byte`
### Path
Contains the path part of the request URL. Optionally, you can override the path by passing a string. For internal redirects, you might want to call [RestartRouting](ctx.md#restartrouting) instead of [Next](ctx.md#next).
```go title="Signature"
func (c fiber.Ctx) Path(override ...string) string
```
```go title="Example"
// GET http://example.com/users?sort=desc
app.Get("/users", func(c fiber.Ctx) error {
c.Path() // "/users"
c.Path("/john")
c.Path() // "/john"
// ...
})
```
### Port
Returns the remote port of the request.
```go title="Signature"
func (c fiber.Ctx) Port() string
```
```go title="Example"
// GET http://example.com:8080 (client connecting from 127.0.0.1:61516)
app.Get("/", func(c fiber.Ctx) error {
c.Port() // "61516" (the client's port, not the server's port 8080)
// ...
})
```
### Protocol
Returns the HTTP protocol version of the request, e.g. `HTTP/1.1`. To get the request scheme (`http` or `https`), use [Schema](#schema).
```go title="Signature"
func (c fiber.Ctx) Protocol() string
```
```go title="Example"
// GET http://example.com
app.Get("/", func(c fiber.Ctx) error {
c.Protocol() // "HTTP/1.1"
// ...
})
```
### Queries
`Queries` returns a map containing an entry for each query string parameter in the request URL. If a key appears more than once, the last value wins.
```go title="Signature"
func (c fiber.Ctx) Queries() map[string]string
```
```go title="Example"
// GET http://example.com/?name=alex&want_pizza=false&id=
app.Get("/", func(c fiber.Ctx) error {
m := c.Queries()
m["name"] // "alex"
m["want_pizza"] // "false"
m["id"] // ""
// ...
})
```
```go title="Example"
// GET http://example.com/?field1=value1&field1=value2&field2=value3
app.Get("/", func (c fiber.Ctx) error {
m := c.Queries()
m["field1"] // "value2"
m["field2"] // "value3"
})
```
```go title="Example"
// GET http://example.com/?list_a=1&list_a=2&list_a=3&list_b[]=1&list_b[]=2&list_b[]=3&list_c=1,2,3
app.Get("/", func(c fiber.Ctx) error {
m := c.Queries()
m["list_a"] // "3"
m["list_b[]"] // "3"
m["list_c"] // "1,2,3"
})
```
```go title="Example"
// GET /api/posts?filters.author.name=John&filters.category.name=Technology
app.Get("/", func(c fiber.Ctx) error {
m := c.Queries()
m["filters.author.name"] // John
m["filters.category.name"] // Technology
})
```
```go title="Example"
// GET /api/posts?tags=apple,orange,banana&filters[tags]=apple,orange,banana&filters[category][name]=fruits&filters.tags=apple,orange,banana&filters.category.name=fruits
app.Get("/", func(c fiber.Ctx) error {
m := c.Queries()
m["tags"] // apple,orange,banana
m["filters[tags]"] // apple,orange,banana
m["filters[category][name]"] // fruits
m["filters.tags"] // apple,orange,banana
m["filters.category.name"] // fruits
})
```
### Query
This method returns a string corresponding to a query string parameter by name. You can pass an optional default value that will be returned if the query key does not exist.
:::info
If there is **no** query string, it returns an **empty string**.
:::
```go title="Signature"
func (c fiber.Ctx) Query(key string, defaultValue ...string) string
```
```go title="Example"
// GET http://example.com/?order=desc&brand=nike
app.Get("/", func(c fiber.Ctx) error {
c.Query("order") // "desc"
c.Query("brand") // "nike"
c.Query("empty", "nike") // "nike"
// ...
})
```
:::info
The returned value is valid only within the handler. Do not store references.
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation)
:::
In certain scenarios, it can be useful to have an alternative approach to handle different types of query parameters, not
just strings. This can be achieved using a generic `Query` function known as `Query[V GenericType](c fiber.Ctx, key string, defaultValue ...V) V`.
This function is capable of parsing a query string and returning a value of a type that is assumed and specified by `V GenericType`.
Here is the signature for the generic `Query` function:
```go title="Signature"
func Query[V GenericType](c fiber.Ctx, key string, defaultValue ...V) V
```
```go title="Example"
// GET http://example.com/?page=1&brand=nike&new=true
app.Get("/", func(c fiber.Ctx) error {
fiber.Query[int](c, "page") // 1
fiber.Query[string](c, "brand") // "nike"
fiber.Query[bool](c, "new") // true
// ...
})
```
In this case, `Query[V GenericType](c Ctx, key string, defaultValue ...V) V` can retrieve `page` as an integer, `brand` as a string, and `new` as a boolean. The function uses the appropriate parsing function for each specified type to ensure the correct type is returned. This simplifies the retrieval process of different types of query parameters, making your controller actions cleaner.
The generic `Query` function supports returning the following data types based on `V GenericType`:
- Integer: `int`, `int8`, `int16`, `int32`, `int64`
- Unsigned integer: `uint`, `uint8`, `uint16`, `uint32`, `uint64`
- Floating-point numbers: `float32`, `float64`
- Boolean: `bool`
- String: `string`
- Byte slice: `[]byte`
### Range
Returns a struct containing the type and a slice of ranges.
Only the canonical `bytes` unit is recognized and any optional
whitespace around range specifiers will be ignored, as specified
in RFC 9110.
If none of the requested ranges are satisfiable, the method automatically
sets the HTTP status code to **416 Range Not Satisfiable** and populates the
`Content-Range` header with the current representation size.
```go title="Signature"
func (c fiber.Ctx) Range(size int64) (Range, error)
```
```go title="Example"
// Range: bytes=500-700, 700-900
app.Get("/", func(c fiber.Ctx) error {
r, err := c.Range(1000)
if err != nil {
return err
}
if r.Type == "bytes" {
for _, rng := range r.Ranges {
fmt.Println(rng)
// {500 700}
// {700 900}
}
}
return nil
})
```
### Referer
Returns the `Referer` request header.
```go title="Signature"
func (c fiber.Ctx) Referer() string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Referer() // "https://example.com"
return nil
})
```
### RequestID
```go title="Signature"
func (c fiber.Ctx) RequestID() string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.RequestID() // "8d7ad5e3-aaf3-450b-a241-2beb887efd54"
return nil
})
```
### SaveFile
Method is used to save **any** multipart file to disk.
```go title="Signature"
func (c fiber.Ctx) SaveFile(fh *multipart.FileHeader, path string) error
```
```go title="Example"
app.Post("/", func(c fiber.Ctx) error {
// Parse the multipart form:
form, err := c.MultipartForm()
if err != nil {
return err
}
// => *multipart.Form
// Get all files from "documents" key:
files := form.File["documents"]
// => []*multipart.FileHeader
// Loop through files:
for _, file := range files {
fmt.Println(file.Filename, file.Size, file.Header["Content-Type"][0])
// => "tutorial.pdf" 360641 "application/pdf"
// Save the files to disk:
if err := c.SaveFile(file, fmt.Sprintf("./%s", file.Filename)); err != nil {
return err
}
}
return nil
})
```
### SaveFileToStorage
Method is used to save **any** multipart file to an external storage system.
```go title="Signature"
func (c fiber.Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error
```
```go title="Example"
storage := memory.New()
app.Post("/", func(c fiber.Ctx) error {
// Parse the multipart form:
form, err := c.MultipartForm()
if err != nil {
return err
}
// => *multipart.Form
// Get all files from "documents" key:
files := form.File["documents"]
// => []*multipart.FileHeader
// Loop through files:
for _, file := range files {
fmt.Println(file.Filename, file.Size, file.Header["Content-Type"][0])
// => "tutorial.pdf" 360641 "application/pdf"
// Save the files to storage:
if err := c.SaveFileToStorage(file, fmt.Sprintf("./%s", file.Filename), storage); err != nil {
return err
}
}
return nil
})
```
### Schema
Contains the request protocol string: `http` or `https` for TLS requests.
:::info
Please use [`Config.TrustProxy`](fiber.md#trustproxy) to prevent header spoofing if your app is behind a proxy.
:::
```go title="Signature"
func (c fiber.Ctx) Schema() string
```
```go title="Example"
// GET http://example.com
app.Get("/", func(c fiber.Ctx) error {
c.Schema() // "http"
// ...
})
```
### Secure
A boolean property that is `true` if a **TLS** connection is established.
```go title="Signature"
func (c fiber.Ctx) Secure() bool
```
```go title="Example"
// Secure() method is equivalent to:
c.Schema() == "https"
```
### Stale
When the client's cached response is **stale**, this method returns **true**. It
is the logical complement of [`Fresh`](#fresh), which checks whether the cached
representation is still valid.
[https://expressjs.com/en/4x/api.html#req.stale](https://expressjs.com/en/4x/api.html#req.stale)
```go title="Signature"
func (c fiber.Ctx) Stale() bool
```
### Subdomains
Returns a slice with the host’s sub-domain labels: the dot-separated parts that precede the registrable domain (`example`) and the top-level domain (e.g. `com`).
The `subdomain offset` (default `2`) tells Fiber how many labels, counting from the right-hand side, are always discarded.
Passing an `offset` argument lets you override that value for a single call.
```go
func (c fiber.Ctx) Subdomains(offset ...int) []string
```
| `offset` | Result | Meaning |
| ---------------------- | --------------------------------------- | --------------------------------------------- |
| *omitted* → **2** | trim 2 right-most labels | drop the registrable domain **and** the TLD |
| `1` to `len(labels)-1` | trim exactly `offset` right-most labels | custom trimming of available labels |
| `>= len(labels)` | **return `[]`** | offset exceeds available labels → empty slice |
| `0` | **return every label** | keep the entire host unchanged |
| `< 0` | **return `[]`** | negative offsets are invalid → empty slice |
#### Example
```go
// Host: "tobi.ferrets.example.com"
app.Get("/", func(c fiber.Ctx) error {
c.Subdomains() // ["tobi", "ferrets"]
c.Subdomains(1) // ["tobi", "ferrets", "example"]
c.Subdomains(0) // ["tobi", "ferrets", "example", "com"]
c.Subdomains(-1) // []
// ...
})
```
### UserAgent
Returns the `User-Agent` request header.
```go title="Signature"
func (c fiber.Ctx) UserAgent() string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.UserAgent() // "Mozilla/5.0 ..."
return nil
})
```
### XHR
A boolean property that is `true` if the request’s [X-Requested-With](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers) header field is [XMLHttpRequest](https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest), indicating that the request was issued by a client library (such as [jQuery](https://api.jquery.com/jQuery.ajax/)).
```go title="Signature"
func (c fiber.Ctx) XHR() bool
```
```go title="Example"
// X-Requested-With: XMLHttpRequest
app.Get("/", func(c fiber.Ctx) error {
c.XHR() // true
// ...
})
```
## Response
Methods which modify the response object.
:::tip
Use `c.Res()` to limit gopls suggestions to only these methods!
:::
### Append
Appends the specified **value** to the HTTP response header field.
:::caution
If the header is **not** already set, it creates the header with the specified value.
:::
```go title="Signature"
func (c fiber.Ctx) Append(field string, values ...string)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Append("Link", "http://google.com", "http://localhost")
// => Link: http://google.com, http://localhost
c.Append("Link", "Test")
// => Link: http://google.com, http://localhost, Test
// ...
})
```
### Attachment
Sets the HTTP response [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header field to `attachment`.
```go title="Signature"
func (c fiber.Ctx) Attachment(filename ...string)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Attachment()
// => Content-Disposition: attachment
c.Attachment("./upload/images/logo.png")
// => Content-Disposition: attachment; filename="logo.png"
// => Content-Type: image/png
// ...
})
```
Non-ASCII filenames are encoded using the `filename*` parameter as defined in
[RFC 6266](https://www.rfc-editor.org/rfc/rfc6266) and
[RFC 8187](https://www.rfc-editor.org/rfc/rfc8187):
```go title="Example"
app.Get("/non-ascii", func(c fiber.Ctx) error {
c.Attachment("./files/文件.txt")
// => Content-Disposition: attachment; filename="文件.txt"; filename*=UTF-8''%E6%96%87%E4%BB%B6.txt
return nil
})
```
### AutoFormat
Performs content-negotiation on the [Accept](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept) HTTP header. It uses [Accepts](ctx.md#accepts) to select a proper format.
The supported content types are `text/html`, `text/plain`, `application/json`, `application/vnd.msgpack`, `application/xml`, and `application/cbor`.
For more flexible content negotiation, use [Format](ctx.md#format).
:::info
If the header is **not** specified or there is **no** proper format, **text/plain** is used.
:::
```go title="Signature"
func (c fiber.Ctx) AutoFormat(body any) error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
// Accept: text/plain
c.AutoFormat("Hello, World!")
// => Hello, World!
// Accept: text/html
c.AutoFormat("Hello, World!")
// => <p>Hello, World!</p>
type User struct {
Name string
}
user := User{"John Doe"}
// Accept: application/json
c.AutoFormat(user)
// => {"Name":"John Doe"}
// Accept: application/vnd.msgpack
c.AutoFormat(user)
// => 81 a4 4e 61 6d 65 a8 4a 6f 68 6e 20 44 6f 65
// Accept: application/cbor
c.AutoFormat(user)
// => a1 64 4e 61 6d 65 68 4a 6f 68 6e 20 44 6f 65
// Accept: application/xml
c.AutoFormat(user)
// => <User><Name>John Doe</Name></User>
// ..
})
```
### CBOR
CBOR converts any interface or string to CBOR encoded bytes.
> **Note:** Before using any CBOR-related features, make sure to follow the [CBOR setup instructions](../guide/advance-format.md#cbor).
:::info
CBOR also sets the content header to the `ctype` parameter. If no `ctype` is passed in, the header is set to `application/cbor`.
:::
```go title="Signature"
func (c fiber.Ctx) CBOR(data any, ctype ...string) error
```
```go title="Example"
type SomeStruct struct {
Name string `cbor:"name"`
Age uint8 `cbor:"age"`
}
app.Get("/cbor", func(c fiber.Ctx) error {
// Create data struct:
data := SomeStruct{
Name: "Grame",
Age: 20,
}
return c.CBOR(data)
// => Content-Type: application/cbor
// => \xa2dnameeGramecage\x14
return c.CBOR(fiber.Map{
"name": "Grame",
"age": 20,
})
// => Content-Type: application/cbor
// => \xa2dnameeGramecage\x14
return c.CBOR(fiber.Map{
"type": "https://example.com/probs/out-of-credit",
"title": "You do not have enough credit.",
"status": 403,
"detail": "Your current balance is 30, but that costs 50.",
"instance": "/account/12345/msgs/abc",
})
// => Content-Type: application/cbor
// => \xa5dtypex'https://example.com/probs/out-of-creditetitlex\x1eYou do not have enough credit.fstatus\x19\x01\x93fdetailx.Your current balance is 30, but that costs 50.hinstancew/account/12345/msgs/abc
})
```
### ClearCookie
Expires a client cookie (or all cookies if left empty).
```go title="Signature"
func (c fiber.Ctx) ClearCookie(key ...string)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
// Clears all cookies:
c.ClearCookie()
// Expire specific cookie by name:
c.ClearCookie("user")
// Expire multiple cookies by names:
c.ClearCookie("token", "session", "track_id", "version")
// ...
})
```
:::caution
Web browsers and other compliant clients will only clear the cookie if the given options are identical to those when creating the cookie, excluding `Expires` and `MaxAge`. `ClearCookie` will not set these values for you - a technique similar to the one shown below should be used to ensure your cookie is deleted.
:::
```go title="Example"
app.Get("/set", func(c fiber.Ctx) error {
c.Cookie(&fiber.Cookie{
Name: "token",
Value: "randomvalue",
Expires: time.Now().Add(24 * time.Hour),
HTTPOnly: true,
SameSite: "Lax",
})
// ...
})
app.Get("/delete", func(c fiber.Ctx) error {
c.Cookie(&fiber.Cookie{
Name: "token",
Expires: fasthttp.CookieExpireDelete, // Use fasthttp's built-in constant
HTTPOnly: true,
SameSite: "Lax",
})
// ...
})
```
You can also use `c.Cookie()` to expire cookies with specific `Path` or `Domain` attributes:
```go title="Example"
app.Get("/logout", func(c fiber.Ctx) error {
// Expire a cookie with path and domain
c.Cookie(&fiber.Cookie{
Name: "token",
Path: "/api",
Domain: "example.com",
Expires: fasthttp.CookieExpireDelete,
})
return c.SendStatus(fiber.StatusOK)
})
```
### Cookie
Sets a cookie.
```go title="Signature"
func (c fiber.Ctx) Cookie(cookie *Cookie)
```
```go
type Cookie struct {
Name string `json:"name"` // The name of the cookie
Value string `json:"value"` // The value of the cookie
Path string `json:"path"` // Specifies a URL path which is allowed to receive the cookie
Domain string `json:"domain"` // Specifies the domain which is allowed to receive the cookie
MaxAge int `json:"max_age"` // The maximum age (in seconds) of the cookie
Expires time.Time `json:"expires"` // The expiration date of the cookie
Secure bool `json:"secure"` // Indicates that the cookie should only be transmitted over a secure HTTPS connection
HTTPOnly bool `json:"http_only"` // Indicates that the cookie is accessible only through the HTTP protocol
SameSite string `json:"same_site"` // Controls whether or not a cookie is sent with cross-site requests
Partitioned bool `json:"partitioned"` // Indicates if the cookie is stored in a partitioned cookie jar
SessionOnly bool `json:"session_only"` // Indicates if the cookie is a session-only cookie
}
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
// Create cookie
cookie := new(fiber.Cookie)
cookie.Name = "john"
cookie.Value = "doe"
cookie.Expires = time.Now().Add(24 * time.Hour)
// Set cookie
c.Cookie(cookie)
// ...
})
```
:::info
When setting a cookie with `SameSite=None`, Fiber automatically sets `Secure=true` as required by RFC 6265bis and modern browsers. This ensures compliance with the "None" SameSite policy which mandates that cookies must be sent over secure connections.
For more information, see:
- [Mozilla Documentation](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#none)
- [Chrome Documentation](https://developers.google.com/search/blog/2020/01/get-ready-for-new-samesitenone-secure)
:::
:::info
Partitioned cookies allow partitioning the cookie jar by top-level site, enhancing user privacy by preventing cookies from being shared across different sites. This feature is particularly useful in scenarios where a user interacts with embedded third-party services that should not have access to the main site's cookies. You can check out [CHIPS](https://developers.google.com/privacy-sandbox/3pcd/chips) for more information.
:::
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
// Create a new partitioned cookie
cookie := new(fiber.Cookie)
cookie.Name = "user_session"
cookie.Value = "abc123"
cookie.Partitioned = true // This cookie will be stored in a separate jar when it's embedded into another website
// Set the cookie in the response
c.Cookie(cookie)
return c.SendString("Partitioned cookie set")
})
```
### Download
Transfers the file from the given path as an `attachment`.
Typically, browsers will prompt the user to download. By default, the [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header `filename=` parameter is the base name of the file path (this typically appears in the browser dialog).
Override this default with the `filename` parameter.
```go title="Signature"
func (c fiber.Ctx) Download(file string, filename ...string) error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return c.Download("./files/report-12345.pdf")
// => Download report-12345.pdf
return c.Download("./files/report-12345.pdf", "report.pdf")
// => Download report.pdf
})
```
For filenames containing non-ASCII characters, a `filename*` parameter is added
according to [RFC 6266](https://www.rfc-editor.org/rfc/rfc6266) and
[RFC 8187](https://www.rfc-editor.org/rfc/rfc8187):
```go title="Example"
app.Get("/non-ascii", func(c fiber.Ctx) error {
return c.Download("./files/文件.txt")
// => Content-Disposition: attachment; filename="文件.txt"; filename*=UTF-8''%E6%96%87%E4%BB%B6.txt
})
```
### End
End immediately flushes the current response and closes the underlying connection.
```go title="Signature"
func (c fiber.Ctx) End() error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.SendString("Hello World!")
return c.End()
})
```
:::caution
Calling `c.End()` will disallow further writes to the underlying connection.
:::
:::warning
`c.End()` does **not** work in streaming mode (e.g. when using `fasthttp`'s `HijackConn` or `SendStream`).
In streaming mode the connection is managed asynchronously and `ctx.Conn()` may return `nil`,
so `c.End()` will return `nil` without flushing or closing the connection.
:::
End can be used to stop a middleware from modifying a response of a handler/other middleware down the method chain
when they regain control after calling `c.Next()`.
```go title="Example"
// Error Logging/Responding middleware
app.Use(func(c fiber.Ctx) error {
err := c.Next()
// Log errors & write the error to the response
if err != nil {
log.Printf("Got error in middleware: %v", err)
return c.Writef("(got error %v)", err)
}
// No errors occurred
return nil
})
// Handler with simulated error
app.Get("/", func(c fiber.Ctx) error {
// Closes the connection instantly after writing from this handler
// and disallow further modification of its response
defer c.End()
c.SendString("Hello, ... I forgot what comes next!")
return errors.New("some error")
})
```
### Format
Performs content-negotiation on the [Accept](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept) HTTP header. It uses [Accepts](ctx.md#accepts) to select a proper format from the supplied offers. A default handler can be provided by setting the `MediaType` to `"default"`. If no offers match and no default is provided, a 406 (Not Acceptable) response is sent. The Content-Type is automatically set when a handler is selected.
:::info
If the Accept header is **not** specified, the first handler will be used.
:::
```go title="Signature"
func (c fiber.Ctx) Format(handlers ...ResFmt) error
```
```go title="Example"
// Accept: application/json => {"command":"eat","subject":"fruit"}
// Accept: text/plain => Eat Fruit!
// Accept: application/xml => Not Acceptable
app.Get("/no-default", func(c fiber.Ctx) error {
return c.Format(
fiber.ResFmt{"application/json", func(c fiber.Ctx) error {
return c.JSON(fiber.Map{
"command": "eat",
"subject": "fruit",
})
}},
fiber.ResFmt{"text/plain", func(c fiber.Ctx) error {
return c.SendString("Eat Fruit!")
}},
)
})
// Accept: application/json => {"command":"eat","subject":"fruit"}
// Accept: text/plain => Eat Fruit!
// Accept: application/xml => Eat Fruit!
app.Get("/default", func(c fiber.Ctx) error {
textHandler := func(c fiber.Ctx) error {
return c.SendString("Eat Fruit!")
}
handlers := []fiber.ResFmt{
{"application/json", func(c fiber.Ctx) error {
return c.JSON(fiber.Map{
"command": "eat",
"subject": "fruit",
})
}},
{"text/plain", textHandler},
{"default", textHandler},
}
return c.Format(handlers...)
})
```
### JSON
Converts any **interface** or **string** to JSON using the [encoding/json](https://pkg.go.dev/encoding/json) package.
:::info
JSON also sets the content header to the `ctype` parameter. If no `ctype` is passed in, the header is set to `application/json; charset=utf-8` by default.
:::
```go title="Signature"
func (c fiber.Ctx) JSON(data any, ctype ...string) error
```
```go title="Example"
type SomeStruct struct {
Name string
Age uint8
}
app.Get("/json", func(c fiber.Ctx) error {
// Create data struct:
data := SomeStruct{
Name: "Grame",
Age: 20,
}
return c.JSON(data)
// => Content-Type: application/json; charset=utf-8
// => {"Name": "Grame", "Age": 20}
return c.JSON(fiber.Map{
"name": "Grame",
"age": 20,
})
// => Content-Type: application/json; charset=utf-8
// => {"name": "Grame", "age": 20}
return c.JSON(fiber.Map{
"type": "https://example.com/probs/out-of-credit",
"title": "You do not have enough credit.",
"status": 403,
"detail": "Your current balance is 30, but that costs 50.",
"instance": "/account/12345/msgs/abc",
}, "application/problem+json")
// => Content-Type: application/problem+json
// => {
// =>   "type": "https://example.com/probs/out-of-credit",
// =>   "title": "You do not have enough credit.",
// =>   "status": 403,
// =>   "detail": "Your current balance is 30, but that costs 50.",
// =>   "instance": "/account/12345/msgs/abc"
// => }
})
```
### JSONP
Sends a JSON response with JSONP support. This method is identical to [JSON](ctx.md#json), except that it opts-in to JSONP callback support. By default, the callback name is simply `callback`.
Override this by passing a **named string** in the method.
```go title="Signature"
func (c fiber.Ctx) JSONP(data any, callback ...string) error
```
```go title="Example"
type SomeStruct struct {
Name string
Age uint8
}
app.Get("/", func(c fiber.Ctx) error {
// Create data struct:
data := SomeStruct{
Name: "Grame",
Age: 20,
}
return c.JSONP(data)
// => callback({"Name": "Grame", "Age": 20})
return c.JSONP(data, "customFunc")
// => customFunc({"Name": "Grame", "Age": 20})
})
```
### Links
Joins the links followed by the property to populate the response’s [Link HTTP header](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link) field.
```go title="Signature"
func (c fiber.Ctx) Links(link ...string)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Links(
"http://api.example.com/users?page=2", "next",
"http://api.example.com/users?page=5", "last",
)
// Link: <http://api.example.com/users?page=2>; rel="next",
//       <http://api.example.com/users?page=5>; rel="last"
// ...
})
```
### Location
Sets the response [Location](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Location) HTTP header to the specified path parameter.
```go title="Signature"
func (c fiber.Ctx) Location(path string)
```
```go title="Example"
app.Post("/", func(c fiber.Ctx) error {
c.Location("http://example.com")
c.Location("/foo/bar")
return nil
})
```
### MsgPack
> **Note:** Before using any MsgPack-related features, make sure to follow the [MsgPack setup instructions](../guide/advance-format.md#msgpack).
A compact binary alternative to [JSON](#json) for efficient data transfer between micro-services or from server to client. MessagePack serializes faster and yields smaller payloads than plain JSON.
Converts any **interface** or **string** to MsgPack using the [shamaton/msgpack](https://pkg.go.dev/github.com/shamaton/msgpack/v3) package.
:::info
MsgPack also sets the content header to the `ctype` parameter. If no `ctype` is passed in, the header is set to `application/vnd.msgpack`.
:::
```go title="Signature"
func (c fiber.Ctx) MsgPack(data any, ctype ...string) error
```
```go title="Example"
type SomeStruct struct {
Name string
Age uint8
}
app.Get("/msgpack", func(c fiber.Ctx) error {
// Create data struct:
data := SomeStruct{
Name: "Grame",
Age: 20,
}
return c.MsgPack(data)
// => Content-Type: application/vnd.msgpack
// => 82 A4 4E 61 6D 65 A5 47 72 61 6D 65 A3 41 67 65 14
return c.MsgPack(fiber.Map{
"name": "Grame",
"age": 20,
})
// => Content-Type: application/vnd.msgpack
// => 82 A4 6E 61 6D 65 A5 47 72 61 6D 65 A3 61 67 65 14
return c.MsgPack(fiber.Map{
"type": "https://example.com/probs/out-of-credit",
"title": "You do not have enough credit.",
"status": 403,
"detail": "Your current balance is 30, but that costs 50.",
"instance": "/account/12345/msgs/abc",
}, "application/problem+msgpack")
// => Content-Type: application/problem+msgpack
// => 85 A4 74 79 70 65 D9 27 68 74 74 70 73 3A 2F 2F 65 78 61 6D 70 6C 65 2E 63 6F 6D 2F 70 72 6F 62 73 2F 6F 75 74 2D 6F 66 2D 63 72 65 64 69 74 A5 74 69 74 6C 65 BE 59 6F 75 20 64 6F 20 6E 6F 74 20 68 61 76 65 20 65 6E 6F 75 67 68 20 63 72 65 64 69 74 2E A6 73 74 61 74 75 73 CD 01 93 A6 64 65 74 61 69 6C D9 2E 59 6F 75 72 20 63 75 72 72 65 6E 74 20 62 61 6C 61 6E 63 65 20 69 73 20 33 30 2C 20 62 75 74 20 74 68 61 74 20 63 6F 73 74 73 20 35 30 2E A8 69 6E 73 74 61 6E 63 65 B7 2F 61 63 63 6F 75 6E 74 2F 31 32 33 34 35 2F 6D 73 67 73 2F 61 62 63
})
```
### Render
Renders a view with data and sends a `text/html` response. By default, `Render` uses the default [**Go Template engine**](https://pkg.go.dev/html/template/). If you want to use another view engine, please take a look at our [**Template middleware**](https://docs.gofiber.io/template).
```go title="Signature"
func (c fiber.Ctx) Render(name string, bind any, layouts ...string) error
```
### Send
Sets the HTTP response body.
```go title="Signature"
func (c fiber.Ctx) Send(body []byte) error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return c.Send([]byte("Hello, World!")) // => "Hello, World!"
})
```
Fiber also provides `SendString` and `SendStream` methods for raw inputs.
:::tip
Use this if you **don't need** type assertion, recommended for **faster** performance.
:::
```go title="Signature"
func (c fiber.Ctx) SendString(body string) error
func (c fiber.Ctx) SendStream(stream io.Reader, size ...int) error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
// => "Hello, World!"
return c.SendStream(bytes.NewReader([]byte("Hello, World!")))
// => "Hello, World!"
})
```
### SendEarlyHints
Sends an informational `103 Early Hints` response with one or more
[`Link` headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Link)
before the final response. This allows the browser to start preloading
resources while the server prepares the full response.
:::caution
Early Hints (`103` responses) are fully supported in HTTP/2 and newer. Legacy HTTP/1.1 clients may not support `SendEarlyHints`: they may ignore these interim responses or misbehave when receiving them.
See [Enabling HTTP/2](../guide/reverse-proxy#enabling-http2) for instructions on how to use a reverse proxy (e.g. Nginx or Traefik) to enable HTTP/2 support.
:::
```go title="Signature"
func (c fiber.Ctx) SendEarlyHints(hints []string) error
```
```go title="Example"
hints := []string{"<https://cdn.com/app.js>; rel=preload; as=script"}
app.Get("/early", func(c fiber.Ctx) error {
if err := c.SendEarlyHints(hints); err != nil {
return err
}
return c.SendString("done")
})
```
### SendFile
Transfers the file from the given path. Sets the [Content-Type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) response HTTP header field based on the **file** extension or format.
```go title="Config"
// SendFile defines configuration options when to transfer file with SendFile.
type SendFile struct {
// FS is the file system to serve the static files from.
// You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc.
//
// Optional. Default: nil
FS fs.FS
// When set to true, the server tries minimizing CPU usage by caching compressed files.
// This works differently than the github.com/gofiber/compression middleware.
// You have to set Content-Encoding header to compress the file.
// Available compression methods are gzip, br, and zstd.
//
// Optional. Default: false
Compress bool `json:"compress"`
// When set to true, enables byte range requests.
//
// Optional. Default: false
ByteRange bool `json:"byte_range"`
// When set to true, enables direct download.
//
// Optional. Default: false
Download bool `json:"download"`
// Expiration duration for inactive file handlers.
// Use a negative time.Duration to disable it.
//
// Optional. Default: 10 * time.Second
CacheDuration time.Duration `json:"cache_duration"`
// The value for the Cache-Control HTTP-header
// that is set on the file response. MaxAge is defined in seconds.
//
// Optional. Default: 0
MaxAge int `json:"max_age"`
}
```
```go title="Signature"
func (c fiber.Ctx) SendFile(file string, config ...SendFile) error
```
```go title="Example"
app.Get("/not-found", func(c fiber.Ctx) error {
return c.SendFile("./public/404.html")
// Disable compression
return c.SendFile("./static/index.html", fiber.SendFile{
Compress: false,
})
})
```
:::info
If the file contains a URL-specific character, you have to escape it before passing the file path into the `SendFile` function.
:::
```go title="Example"
app.Get("/file-with-url-chars", func(c fiber.Ctx) error {
return c.SendFile(url.PathEscape("hash_sign_#.txt"))
})
```
:::info
You can set the `CacheDuration` config property to `-1` to disable caching.
:::
```go title="Example"
app.Get("/file", func(c fiber.Ctx) error {
return c.SendFile("style.css", fiber.SendFile{
CacheDuration: -1,
})
})
```
:::info
You can use multiple `SendFile` calls with different configurations in a single route. Fiber creates different filesystem handlers per config.
:::
```go title="Example"
app.Get("/file", func(c fiber.Ctx) error {
switch c.Query("config") {
case "filesystem":
return c.SendFile("style.css", fiber.SendFile{
FS: os.DirFS("."),
})
case "filesystem-compress":
return c.SendFile("style.css", fiber.SendFile{
FS: os.DirFS("."),
Compress: true,
})
case "compress":
return c.SendFile("style.css", fiber.SendFile{
Compress: true,
})
default:
return c.SendFile("style.css")
}
})
```
:::info
For sending multiple files from an embedded file system, [this functionality](../middleware/static.md#serving-files-using-embedfs) can be used.
:::
### SendStatus
Sets the status code and the correct status message in the body if the response body is **empty**.
:::tip
You can find all used status codes and messages [in the Fiber source code](https://github.com/gofiber/fiber/blob/dffab20bcdf4f3597d2c74633a7705a517d2c8c2/utils.go#L183-L244).
:::
```go title="Signature"
func (c fiber.Ctx) SendStatus(status int) error
```
```go title="Example"
app.Get("/not-found", func(c fiber.Ctx) error {
return c.SendStatus(415)
// => 415 "Unsupported Media Type"
c.SendString("Hello, World!")
return c.SendStatus(415)
// => 415 "Hello, World!"
})
```
### SendStream
Sets the response body to a stream of data and adds an optional body size.
```go title="Signature"
func (c fiber.Ctx) SendStream(stream io.Reader, size ...int) error
```
:::info
`SendStream` operates asynchronously. The handler returns immediately after setting up the stream,
but the actual reading and sending of data happens **after** the handler completes. This is handled
by the underlying `fasthttp` library.
If the provided stream implements `io.Closer`, it will be automatically closed by `fasthttp` after
the response is fully sent or if an error occurs.
:::
:::caution
When passing `fiber.Ctx` as a `context.Context` to libraries that spawn goroutines (e.g., for streaming operations),
those goroutines may attempt to access the context after the handler returns. Since `fiber.Ctx` is recycled and
released after the handler completes, this can cause issues.
**Recommended approach**: Use `c.Context()` or `c.RequestCtx()` instead of passing `c` directly to such libraries.
See the [Context Guide](../guide/context.md) for more details.
:::
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return c.SendStream(bytes.NewReader([]byte("Hello, World!")))
// => "Hello, World!"
})
```
```go title="Example with file streaming"
app.Get("/download", func(c fiber.Ctx) error {
file, err := os.Open("large-file.zip")
if err != nil {
return err
}
// File will be automatically closed by fasthttp after streaming completes
stat, err := file.Stat()
if err != nil {
file.Close()
return err
}
return c.SendStream(file, int(stat.Size()))
})
```
### SendStreamWriter
Sets the response body stream writer.
:::note
The argument `streamWriter` represents a function that populates
the response body using a buffered stream writer.
:::
```go title="Signature"
func (c fiber.Ctx) SendStreamWriter(streamWriter func(*bufio.Writer)) error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return c.SendStreamWriter(func(w *bufio.Writer) {
fmt.Fprintf(w, "Hello, World!\n")
})
// => "Hello, World!"
})
```
:::info
To send data before `streamWriter` returns, you can call `w.Flush()`
on the provided writer. Otherwise, the buffered stream flushes after
`streamWriter` returns.
:::
:::note
`w.Flush()` will return an error if the client disconnects before `streamWriter` finishes writing a response.
:::
```go title="Example"
app.Get("/wait", func(c fiber.Ctx) error {
return c.SendStreamWriter(func(w *bufio.Writer) {
// Begin Work
fmt.Fprintf(w, "Please wait for 10 seconds\n")
if err := w.Flush(); err != nil {
log.Print("Client disconnected!")
return
}
// Send progress over time
time.Sleep(time.Second)
for i := 0; i < 9; i++ {
fmt.Fprintf(w, "Still waiting...\n")
if err := w.Flush(); err != nil {
// If client disconnected, cancel work and finish
log.Print("Client disconnected!")
return
}
time.Sleep(time.Second)
}
// Finish
fmt.Fprintf(w, "Done!\n")
})
})
```
### SendString
Sets the response body to a string.
```go title="Signature"
func (c fiber.Ctx) SendString(body string) error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
// => "Hello, World!"
})
```
### Set
Sets the response’s HTTP header field to the specified `key`, `value`.
```go title="Signature"
func (c fiber.Ctx) Set(key string, val string)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Set("Content-Type", "text/plain")
// => "Content-Type: text/plain"
// ...
})
```
### Status
Sets the HTTP status for the response.
:::info
This method is **chainable**.
:::
```go title="Signature"
func (c fiber.Ctx) Status(status int) fiber.Ctx
```
```go title="Example"
app.Get("/fiber", func(c fiber.Ctx) error {
c.Status(fiber.StatusOK)
return nil
})
app.Get("/hello", func(c fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).SendString("Bad Request")
})
app.Get("/world", func(c fiber.Ctx) error {
return c.Status(fiber.StatusNotFound).SendFile("./public/gopher.png")
})
```
### Type
Sets the [Content-Type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) HTTP header to the MIME type listed [in the Nginx MIME types configuration](https://github.com/nginx/nginx/blob/master/conf/mime.types) specified by the file **extension**.
:::info
This method is **chainable**.
:::
```go title="Signature"
func (c fiber.Ctx) Type(ext string, charset ...string) fiber.Ctx
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Type(".html") // => "text/html"
c.Type("html") // => "text/html"
c.Type("png") // => "image/png"
c.Type("json", "utf-8") // => "application/json; charset=utf-8"
// ...
})
```
### Vary
Adds the given header field to the [Vary](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary) response header. This will append the header if not already listed; otherwise, it leaves it listed in the current location.
:::info
Multiple fields are **allowed**.
:::
```go title="Signature"
func (c fiber.Ctx) Vary(fields ...string)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Vary("Origin") // => Vary: Origin
c.Vary("User-Agent") // => Vary: Origin, User-Agent
// No duplicates
c.Vary("Origin") // => Vary: Origin, User-Agent
c.Vary("Accept-Encoding", "Accept")
// => Vary: Origin, User-Agent, Accept-Encoding, Accept
// ...
})
```
### Write
Implements the `io.Writer` interface; written bytes are appended to the response body.
```go title="Signature"
func (c fiber.Ctx) Write(p []byte) (n int, err error)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
c.Write([]byte("Hello, World!")) // appends "Hello, World!"
fmt.Fprintf(c, "%s\n", "Hello, World!") // appends "Hello, World!\n"
return nil
// => "Hello, World!Hello, World!\n"
})
```
### Writef
Writes a formatted string using a format specifier.
```go title="Signature"
func (c fiber.Ctx) Writef(format string, a ...any) (n int, err error)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
world := "World!"
c.Writef("Hello, %s", world) // appends "Hello, World!"
fmt.Fprintf(c, "%s\n", "Hello, World!") // appends "Hello, World!\n"
return nil
// => "Hello, World!Hello, World!\n"
})
```
### WriteString
Writes a string to the response body.
```go title="Signature"
func (c fiber.Ctx) WriteString(s string) (n int, err error)
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
_, err := c.WriteString("Hello, World!")
return err
// => "Hello, World!"
})
```
### XML
Converts any **interface** or **string** to XML using the configured `XMLEncoder` (the standard `encoding/xml` package by default).
:::info
XML also sets the content header to `application/xml; charset=utf-8`.
:::
```go title="Signature"
func (c fiber.Ctx) XML(data any) error
```
```go title="Example"
type SomeStruct struct {
XMLName xml.Name `xml:"Fiber"`
Name string `xml:"Name"`
Age uint8 `xml:"Age"`
}
app.Get("/", func(c fiber.Ctx) error {
// Create data struct:
data := SomeStruct{
Name: "Grame",
Age: 20,
}
return c.XML(data)
// => <Fiber><Name>Grame</Name><Age>20</Age></Fiber>
})
```
================================================
FILE: docs/api/fiber.md
================================================
---
id: fiber
title: 📦 Fiber
description: Fiber represents the fiber package where you start to create an instance.
sidebar_position: 1
---
import Reference from '@site/src/components/reference';
## Server start
### New
This method creates a new **App** instance. You can pass an optional [config](#config) when creating a new instance.
```go title="Signature"
func New(config ...Config) *App
```
```go title="Example"
// Default config
app := fiber.New()
// ...
```
### Config
You can pass an optional Config when creating a new Fiber instance.
```go title="Example"
// Custom config
app := fiber.New(fiber.Config{
CaseSensitive: true,
StrictRouting: true,
ServerHeader: "Fiber",
AppName: "Test App v1.0.1",
})
// ...
```
#### Config fields
| Property | Type | Description | Default |
|---------------------------------------------------------------------------------------|-----------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------|
| AppName | `string` | Sets the application name used in logs and the Server header | `""` |
| BodyLimit | `int` | Sets the maximum allowed size for a request body. Zero or negative values fall back to the default limit. If the size exceeds the configured limit, it sends `413 - Request Entity Too Large` response. This limit also applies when running Fiber through the adaptor middleware from `net/http`. | `4 * 1024 * 1024` |
| CaseSensitive | `bool` | When enabled, `/Foo` and `/foo` are different routes. When disabled, `/Foo` and `/foo` are treated the same. | `false` |
| CBORDecoder | `utils.CBORUnmarshal` | Allowing for flexibility in using another cbor library for decoding. | `binder.UnimplementedCborUnmarshal` |
| CBOREncoder | `utils.CBORMarshal` | Allowing for flexibility in using another cbor library for encoding. | `binder.UnimplementedCborMarshal` |
| ColorScheme | [`Colors`](https://github.com/gofiber/fiber/blob/main/color.go) | You can define custom color scheme. They'll be used for startup message, route list and some middlewares. | [`DefaultColors`](https://github.com/gofiber/fiber/blob/main/color.go) |
| CompressedFileSuffixes | `map[string]string` | Adds a suffix to the original file name and tries saving the resulting compressed file under the new file name. | `{"gzip": ".fiber.gz", "br": ".fiber.br", "zstd": ".fiber.zst"}` |
| Concurrency | `int` | Maximum number of concurrent connections. | `256 * 1024` |
| DisableDefaultContentType | `bool` | When true, omits the default Content-Type header from the response. | `false` |
| DisableDefaultDate | `bool` | When true, omits the Date header from the response. | `false` |
| DisableHeadAutoRegister | `bool` | Prevents Fiber from automatically registering `HEAD` routes for each `GET` route so you can supply custom `HEAD` handlers; manual `HEAD` routes still override the generated ones. | `false` |
| DisableHeaderNormalizing | `bool` | By default all header names are normalized: conteNT-tYPE -> Content-Type | `false` |
| DisableKeepalive | `bool` | Disables keep-alive connections so the server closes each connection after the first response. | `false` |
| DisablePreParseMultipartForm | `bool` | Will not pre parse Multipart Form data if set to true. This option is useful for servers that desire to treat multipart form data as a binary blob, or choose when to parse the data. | `false` |
| EnableIPValidation | `bool` | If set to true, `c.IP()` and `c.IPs()` will validate IP addresses before returning them. Also, `c.IP()` will return only the first valid IP rather than just the raw header value that may be a comma-separated string.<br /><br />**WARNING:** There is a small performance cost to doing this validation. Keep disabled if speed is your only concern and your application is behind a trusted proxy that already validates this header. | `false` |
| EnableSplittingOnParsers | `bool` | Splits query, body, and header parameters on commas when enabled.<br /><br />For example, `/api?foo=bar,baz` becomes `foo[]=bar&foo[]=baz`. | `false` |
| ErrorHandler | `ErrorHandler` | ErrorHandler is executed when an error is returned from fiber.Handler. Mounted fiber error handlers are retained by the top-level app and applied on prefix associated requests. | `DefaultErrorHandler` |
| GETOnly | `bool` | Rejects all non-GET requests if set to true. This option is useful as anti-DoS protection for servers accepting only GET requests. The request size is limited by ReadBufferSize if GETOnly is set. | `false` |
| IdleTimeout | `time.Duration` | The maximum amount of time to wait for the next request when keep-alive is enabled. If IdleTimeout is zero, the value of ReadTimeout is used. | `0` |
| Immutable | `bool` | When enabled, all values returned by context methods are immutable. By default, they are valid until you return from the handler; see issue [\#185](https://github.com/gofiber/fiber/issues/185). | `false` |
| JSONDecoder | `utils.JSONUnmarshal` | Allowing for flexibility in using another json library for decoding. | `json.Unmarshal` |
| JSONEncoder | `utils.JSONMarshal` | Allowing for flexibility in using another json library for encoding. | `json.Marshal` |
| MaxRanges | `int` | Sets the maximum number of ranges parsed from a `Range` header. Zero or negative values fall back to the default limit. If the limit is exceeded, the request is rejected with `416 - Requested Range Not Satisfiable` and `Content-Range: bytes */`. | `16` |
| MsgPackDecoder | `utils.MsgPackUnmarshal` | Allowing for flexibility in using another msgpack library for decoding. | `binder.UnimplementedMsgpackUnmarshal` |
| MsgPackEncoder | `utils.MsgPackMarshal` | Allowing for flexibility in using another msgpack library for encoding. | `binder.UnimplementedMsgpackMarshal` |
| PassLocalsToContext | `bool` | Controls whether `StoreInContext` also propagates values into the request `context.Context` for Fiber-backed contexts. `StoreInContext` always writes to `c.Locals()`. `ValueFromContext` for Fiber-backed contexts always reads from `c.Locals()`. | `false` |
| PassLocalsToViews | `bool` | PassLocalsToViews Enables passing of the locals set on a fiber.Ctx to the template engine. See our **Template Middleware** for supported engines. | `false` |
| ProxyHeader | `string` | Specifies the header name to read the client's real IP address from when behind a reverse proxy. Common values: `fiber.HeaderXForwardedFor`, `"X-Real-IP"`, `"CF-Connecting-IP"` (Cloudflare).<br /><br />**Important:** This setting **requires** `TrustProxy` to be enabled; `TrustProxyConfig` controls which proxy IPs are trusted for reading this header. Without `TrustProxy`, this setting has no effect and `c.IP()` will always return the remote IP from the TCP connection.<br /><br />**Behavior note:** `X-Forwarded-For` often contains a comma-separated chain of IP addresses. With the default `EnableIPValidation = false`, `c.IP()` will return the raw header value (the whole chain) rather than a single parsed client IP. With `EnableIPValidation = true`, `c.IP()` parses the header and returns the **first syntactically valid IP address** it finds; it does **not** walk the chain to find the first non-proxy hop. For a reliable client IP, configure your reverse proxy to overwrite or sanitize this header and/or to provide a single-IP header such as `"X-Real-IP"` or a provider-specific header like `"CF-Connecting-IP"`.<br /><br />**Security Warning:** Headers can be easily spoofed. Always configure `TrustProxyConfig` to validate the proxy IP address, otherwise malicious clients can forge headers to bypass IP-based access controls. | `""` |
| ReadBufferSize | `int` | Per-connection buffer size for reading requests. This also limits the maximum header size. Increase this buffer if your clients send multi-KB RequestURIs and/or multi-KB headers \(for example, BIG cookies\). | `4096` |
| ReadTimeout | `time.Duration` | The amount of time allowed to read the full request, including the body. The default timeout is unlimited. | `0` |
| ReduceMemoryUsage | `bool` | Aggressively reduces memory usage at the cost of higher CPU usage if set to true. | `false` |
| RequestMethods | `[]string` | RequestMethods provides customizability for HTTP methods. You can add/remove methods as you wish. | `DefaultMethods` |
| ServerHeader | `string` | Enables the `Server` HTTP header with the given value. | `""` |
| StreamRequestBody | `bool` | StreamRequestBody enables request body streaming, and calls the handler sooner when given body is larger than the current limit. | `false` |
| StrictRouting | `bool` | When enabled, the router treats `/foo` and `/foo/` as different. Otherwise, the router treats `/foo` and `/foo/` as the same. | `false` |
| StructValidator | `StructValidator` | Validator that is applied automatically to structs bound from headers, forms, queries, and other sources. Fiber ships no default validator, so the validation step is skipped unless you set one. | `nil` |
| TrustProxy | `bool` | Enables trust of reverse proxy headers. When enabled, Fiber will check if the request is coming from a trusted proxy (configured in `TrustProxyConfig`) before reading values from proxy headers.<br /><br />**Required for**: Using `ProxyHeader` to read client IP from headers like `X-Forwarded-For`.<br /><br />**Behavior when enabled:** If the remote IP is trusted (matches `TrustProxyConfig`), then `c.IP()` reads from `ProxyHeader` (when configured; otherwise it uses `RemoteIP()`), `c.Scheme()` first checks standard proxy scheme headers (`X-Forwarded-Proto`, `X-Forwarded-Protocol`, `X-Forwarded-Ssl`, `X-Url-Scheme`) and falls back to the actual connection scheme if none are set, and `c.Hostname()` prefers `X-Forwarded-Host` but falls back to the request Host header when the proxy header is not present. If the remote IP is NOT trusted, these methods ignore proxy headers and use the actual connection values instead.<br /><br />**Security:** This prevents header spoofing by validating the proxy's IP address. Always configure `TrustProxyConfig` when enabling this option and set `ProxyHeader` if you want `c.IP()` to use a specific header. | `false` |
| TrustProxyConfig | `TrustProxyConfig` | Configures which proxy IP addresses or ranges to trust. Only effective when `TrustProxy` is enabled.<br /><br />**Example:** For an app behind Nginx at 10.10.0.58, use `TrustProxyConfig{Proxies: []string{"10.10.0.58"}}` or `TrustProxyConfig{Private: true}` if using private network IPs. | `{}` |
| UnescapePath | `bool` | Converts all encoded characters in the route back before setting the path for the context, so that the routing can also work with URL encoded special characters | `false` |
| Views | `Views` | Views is the interface that wraps the Render function. See our **Template Middleware** for supported engines. | `nil` |
| ViewsLayout | `string` | Views Layout is the global layout for all template render until override on Render function. See our **Template Middleware** for supported engines. | `""` |
| WriteBufferSize | `int` | Per-connection buffer size for writing responses. | `4096` |
| WriteTimeout | `time.Duration` | The maximum duration before timing out writes of the response. The default timeout is unlimited. | `0` |
| XMLDecoder | `utils.XMLUnmarshal` | Allowing for flexibility in using another XML library for decoding. | `xml.Unmarshal` |
| XMLEncoder | `utils.XMLMarshal` | Allowing for flexibility in using another XML library for encoding. | `xml.Marshal` |
## Server listening
### Config
You can pass an optional ListenConfig when calling the [`Listen`](#listen) or [`Listener`](#listener) method.
```go title="Example"
// Custom config
app.Listen(":8080", fiber.ListenConfig{
EnablePrefork: true,
DisableStartupMessage: true,
})
```
#### Config fields
| Property | Type | Description | Default |
|-------------------------------------------------------------------------|-------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------|
| BeforeServeFunc | `func(app *App) error` | Allows customizing and accessing fiber app before serving the app. | `nil` |
| CertClientFile | `string` | Path of the client certificate. If you want to use mTLS, you must enter this field. | `""` |
| CertFile | `string` | Path of the certificate file. If you want to use TLS, you must enter this field. | `""` |
| CertKeyFile | `string` | Path of the certificate's private key. If you want to use TLS, you must enter this field. | `""` |
| DisableStartupMessage | `bool` | When set to true, it will not print out the «Fiber» ASCII art and listening address. | `false` |
| EnablePrefork | `bool` | When set to true, this will spawn multiple Go processes listening on the same port. | `false` |
| EnablePrintRoutes | `bool` | If set to true, will print all routes with their method, path, and handler. | `false` |
| GracefulContext | `context.Context` | Field to shutdown Fiber by given context gracefully. | `nil` |
| ShutdownTimeout | `time.Duration` | Specifies the maximum duration to wait for the server to gracefully shutdown. When the timeout is reached, the graceful shutdown process is interrupted and forcibly terminated, and the `context.DeadlineExceeded` error is passed to the `OnPostShutdown` callback. Set to 0 to disable the timeout and wait indefinitely. | `10 * time.Second` |
| ListenerAddrFunc | `func(addr net.Addr)` | Allows accessing the listener's network address (`net.Addr`). | `nil` |
| ListenerNetwork | `string` | Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), "unix" (Unix Domain Sockets). WARNING: When prefork is set to true, only "tcp4" and "tcp6" can be chosen. | `tcp4` |
| UnixSocketFileMode | `os.FileMode` | FileMode to set for Unix Domain Socket (ListenerNetwork must be "unix") | `0770` |
| TLSConfigFunc | `func(tlsConfig *tls.Config)` | Allows customizing `tls.Config` as you want. Ignored when `TLSConfig` is set. | `nil` |
| TLSConfig | `*tls.Config` | Recommended base TLS configuration (cloned). Use for external certificate providers via `GetCertificate`. When set, other TLS fields are ignored. | `nil` |
| AutoCertManager | `*autocert.Manager` | Manages TLS certificates automatically using the ACME protocol. Enables integration with Let's Encrypt or other ACME-compatible providers. | `nil` |
| TLSMinVersion | `uint16` | Allows customizing the TLS minimum version. | `tls.VersionTLS12` |
### Listen
Listen serves HTTP requests from the given address.
```go title="Signature"
func (app *App) Listen(addr string, config ...ListenConfig) error
```
```go title="Basic Listen usage"
// Listen on port :8080
app.Listen(":8080")
// Listen on port :8080 with Prefork
app.Listen(":8080", fiber.ListenConfig{EnablePrefork: true})
// Custom host
app.Listen("127.0.0.1:8080")
```
#### Prefork
Prefork is a feature that allows you to spawn multiple Go processes listening on the same port. This can be useful for scaling across multiple CPU cores.
```go title="Prefork listener"
app.Listen(":8080", fiber.ListenConfig{EnablePrefork: true})
```
This distributes the incoming connections between the spawned processes and allows more requests to be handled simultaneously.
#### TLS
Prefer `TLSConfig` for TLS configuration so you can fully control certificates and settings. When `TLSConfig` is set, Fiber ignores `CertFile`, `CertKeyFile`, `CertClientFile`, `TLSMinVersion`, `AutoCertManager`, and `TLSConfigFunc`.
To serve HTTPS requests on the given address, set `CertFile` and `CertKeyFile` to the paths of the TLS certificate and its private key.
```go title="TLS with cert and key files"
app.Listen(":443", fiber.ListenConfig{CertFile: "./cert.pem", CertKeyFile: "./cert.key"})
```
#### TLS with client CA certificate
`CertClientFile` only configures the client CA for mTLS when using `CertFile`/`CertKeyFile`. If `TLSConfig` is set, `CertClientFile` is ignored, so configure client CAs in the provided `tls.Config` instead.
```go title="TLS with client CA certificate"
app.Listen(":443", fiber.ListenConfig{
CertFile: "./cert.pem",
CertKeyFile: "./cert.key",
CertClientFile: "./ca-chain-cert.pem",
})
```
#### TLS AutoCert support (ACME / Let's Encrypt)
Provides automatic access to certificates management from Let's Encrypt and any other ACME-based providers.
```go title="AutoCert (ACME) configuration"
// Certificate manager
certManager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
// Replace with your domain name
HostPolicy: autocert.HostWhitelist("example.com"),
// Folder to store the certificates
Cache: autocert.DirCache("./certs"),
}
// ACME TLS-ALPN-01 challenges are validated on port 443
app.Listen(":443", fiber.ListenConfig{
AutoCertManager: certManager,
})
```
#### Precedence and conflicts
- `TLSConfig` is preferred and ignores `CertFile`/`CertKeyFile`, `CertClientFile`, `AutoCertManager`, `TLSMinVersion`, and `TLSConfigFunc`.
- `AutoCertManager` cannot be combined with `CertFile`/`CertKeyFile`.
#### TLS with external certificate provider
Use `TLSConfig` to supply a base `tls.Config` that can fetch certificates at runtime. `TLSConfig` is cloned and used as-is.
```go title="TLSConfig with dynamic certificate provider"
app.Listen(":443", fiber.ListenConfig{
TLSConfig: &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return myProvider.Certificate(info.ServerName)
},
},
})
```
#### Mutual TLS with TLSConfig
Use `TLSConfig` to configure mutual TLS by setting `ClientAuth` and `ClientCAs`. This replaces `CertClientFile` when you manage TLS configuration directly.
```go title="TLSConfig with client CA pool"
certPEM := []byte(certPEMString)
keyPEM := []byte(keyPEMString)
caPEM := []byte(caPEMString)
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
log.Fatal(err)
}
clientCAs := x509.NewCertPool()
if ok := clientCAs.AppendCertsFromPEM(caPEM); !ok {
log.Fatal("failed to append client CA")
}
app.Listen(":443", fiber.ListenConfig{
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: clientCAs,
},
})
```
Load certificates from memory or environment variables and provide them via `TLSConfig`.
```go title="TLSConfig with in-memory certificate"
certPEM := []byte(certPEMString)
keyPEM := []byte(keyPEMString)
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
log.Fatal(err)
}
app.Listen(":443", fiber.ListenConfig{
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
})
```
```go title="TLSConfig with certificate from environment"
certPEM := []byte(os.Getenv("TLS_CERT_PEM"))
keyPEM := []byte(os.Getenv("TLS_KEY_PEM"))
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
log.Fatal(err)
}
app.Listen(":443", fiber.ListenConfig{
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
})
```
### Listener
You can pass your own [`net.Listener`](https://pkg.go.dev/net/#Listener) using the `Listener` method. This method can be used to enable **TLS/HTTPS** with a custom tls.Config.
```go title="Signature"
func (app *App) Listener(ln net.Listener, config ...ListenConfig) error
```
```go title="Examples"
ln, _ := net.Listen("tcp", ":3000")
cer, _:= tls.LoadX509KeyPair("server.crt", "server.key")
ln = tls.NewListener(ln, &tls.Config{Certificates: []tls.Certificate{cer}})
app.Listener(ln)
```
## Server
Server returns the underlying [fasthttp server](https://godoc.org/github.com/valyala/fasthttp#Server)
```go title="Signature"
func (app *App) Server() *fasthttp.Server
```
```go title="Examples"
func main() {
app := fiber.New()
app.Server().MaxConnsPerIP = 1
// ...
}
```
## Server Shutdown
Shutdown gracefully shuts down the server without interrupting any active connections. Shutdown works by first closing all open listeners and then waiting indefinitely for all connections to return to idle before shutting down.
ShutdownWithTimeout will forcefully close any active connections after the timeout expires.
ShutdownWithContext shuts down the server gracefully and forcefully closes remaining connections if the context's deadline is exceeded. Shutdown hooks will still be executed, even if an error occurs during the shutdown process, as they are deferred to ensure cleanup happens regardless of errors.
```go
func (app *App) Shutdown() error
func (app *App) ShutdownWithTimeout(timeout time.Duration) error
func (app *App) ShutdownWithContext(ctx context.Context) error
```
## Helper functions
### NewError
NewError creates a new `*fiber.Error` instance with the given status code and an optional message.
```go title="Signature"
func NewError(code int, message ...string) *Error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return fiber.NewError(782, "Custom error message")
})
```
### NewErrorf
NewErrorf creates a new `*fiber.Error` instance with the given status code and an optional formatted message.
```go title="Signature"
func NewErrorf(code int, message ...any) *Error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return fiber.NewErrorf(782, "Custom error %s", "message")
})
```
### IsChild
IsChild determines if the current process is a result of Prefork.
```go title="Signature"
func IsChild() bool
```
```go title="Example"
// Config app
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
if !fiber.IsChild() {
fmt.Println("I'm the parent process")
} else {
fmt.Println("I'm a child process")
}
return c.SendString("Hello, World!")
})
// ...
// With prefork enabled, the parent process will spawn child processes
app.Listen(":8080", fiber.ListenConfig{EnablePrefork: true})
```
================================================
FILE: docs/api/hooks.md
================================================
---
id: hooks
title: 🎣 Hooks
sidebar_position: 7
---
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
Fiber lets you run custom callbacks at specific points in the application lifecycle (route and group registration, listening, forking, mounting, and shutdown). Available hooks include:
- [OnRoute](#onroute)
- [OnName](#onname)
- [OnGroup](#ongroup)
- [OnGroupName](#ongroupname)
- [OnListen](#onlisten)
- [OnPreStartupMessage/OnPostStartupMessage](#onprestartupmessageonpoststartupmessage)
- [ListenData](#listendata)
- [OnFork](#onfork)
- [OnPreShutdown](#onpreshutdown)
- [OnPostShutdown](#onpostshutdown)
- [OnMount](#onmount)
## Constants
```go
// Handlers define functions to create hooks for Fiber.
type OnRouteHandler = func(Route) error
type OnNameHandler = OnRouteHandler
type OnGroupHandler = func(Group) error
type OnGroupNameHandler = OnGroupHandler
type OnListenHandler = func(ListenData) error
type OnForkHandler = func(int) error
type OnPreStartupMessageHandler = func(*PreStartupMessageData) error
type OnPostStartupMessageHandler = func(*PostStartupMessageData) error
type OnPreShutdownHandler = func() error
type OnPostShutdownHandler = func(error) error
type OnMountHandler = func(*App) error
```
## OnRoute
Runs after each route is registered. The callback receives the route so you can inspect its properties.
```go title="Signature"
func (h *Hooks) OnRoute(handler ...OnRouteHandler)
```
## OnName
Runs when a route is named. The callback receives the route.
:::caution
`OnName` only works with named routes, not groups.
:::
```go title="Signature"
func (h *Hooks) OnName(handler ...OnNameHandler)
```
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
return c.SendString(c.Route().Name)
}).Name("index")
app.Hooks().OnName(func(r fiber.Route) error {
fmt.Print("Name: " + r.Name + ", ")
return nil
})
app.Hooks().OnName(func(r fiber.Route) error {
fmt.Print("Method: " + r.Method + "\n")
return nil
})
app.Get("/add/user", func(c fiber.Ctx) error {
return c.SendString(c.Route().Name)
}).Name("addUser")
app.Delete("/destroy/user", func(c fiber.Ctx) error {
return c.SendString(c.Route().Name)
}).Name("destroyUser")
app.Listen(":5000")
}
// Results:
// Name: addUser, Method: GET
// Name: destroyUser, Method: DELETE
```
## OnGroup
Runs after each group is registered. The callback receives the group.
```go title="Signature"
func (h *Hooks) OnGroup(handler ...OnGroupHandler)
```
## OnGroupName
Runs when a group is named. The callback receives the group.
:::caution
`OnGroupName` only works with named groups, not routes.
:::
```go title="Signature"
func (h *Hooks) OnGroupName(handler ...OnGroupNameHandler)
```
## OnListen
Runs when the app starts listening via `Listen` or `Listener`.
```go title="Signature"
func (h *Hooks) OnListen(handler ...OnListenHandler)
```
```go
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Hooks().OnListen(func(listenData fiber.ListenData) error {
if fiber.IsChild() {
return nil
}
scheme := "http"
if listenData.TLS {
scheme = "https"
}
log.Println(scheme + "://" + listenData.Host + ":" + listenData.Port)
return nil
})
// DisableStartupMessage is a ListenConfig option, not an app Config option.
app.Listen(":5000", fiber.ListenConfig{
DisableStartupMessage: true,
})
}
```
## OnPreStartupMessage/OnPostStartupMessage
Use `OnPreStartupMessage` to tweak the banner before Fiber prints it, and `OnPostStartupMessage` to run logic after the banner is printed (or skipped). You can use some helper functions to customize the banner inside the `OnPreStartupMessage` hook.
```go title="Signatures"
// AddInfo adds an informational entry to the startup message with "INFO" label.
func (sm *PreStartupMessageData) AddInfo(key, title, value string, priority ...int)
// AddWarning adds a warning entry to the startup message with "WARNING" label.
func (sm *PreStartupMessageData) AddWarning(key, title, value string, priority ...int)
// AddError adds an error entry to the startup message with "ERROR" label.
func (sm *PreStartupMessageData) AddError(key, title, value string, priority ...int)
// EntryKeys returns all entry keys currently present in the startup message.
func (sm *PreStartupMessageData) EntryKeys() []string
// ResetEntries removes all existing entries from the startup message.
func (sm *PreStartupMessageData) ResetEntries()
// DeleteEntry removes a specific entry from the startup message by its key.
func (sm *PreStartupMessageData) DeleteEntry(key string)
```
- Assign `sm.BannerHeader` to override the ASCII art banner. Leave it empty to use the default banner provided by Fiber.
- Set `sm.PreventDefault = true` to suppress the built-in banner without affecting other hooks.
- `PostStartupMessageData` reports whether the banner was skipped via the `Disabled`, `IsChild`, and `Prevented` flags.
### Startup Message Customization
```go title="Customize the startup message"
package main
import (
"fmt"
"os"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Hooks().OnPreStartupMessage(func(sm *fiber.PreStartupMessageData) error {
sm.BannerHeader = "FOOBER " + sm.Version + "\n-------"
// Optional: you can also remove old entries
// sm.ResetEntries()
sm.AddInfo("git-hash", "Git hash", os.Getenv("GIT_HASH"))
sm.AddInfo("prefork", "Prefork", fmt.Sprintf("%v", sm.Prefork), 15)
return nil
})
app.Hooks().OnPostStartupMessage(func(sm *fiber.PostStartupMessageData) error {
if !sm.Disabled && !sm.IsChild && !sm.Prevented {
fmt.Println("startup completed")
}
return nil
})
app.Listen(":5000")
}
```
### ListenData
`ListenData` exposes runtime metadata about the listener:
| Field | Type | Description |
| --- | --- | --- |
| `Host` | `string` | Resolved hostname or IP address. |
| `Port` | `string` | The bound port. |
| `TLS` | `bool` | Indicates whether TLS is enabled. |
| `Version` | `string` | Fiber version reported in the startup banner. |
| `AppName` | `string` | Application name from the configuration. |
| `HandlerCount` | `int` | Total registered handler count. |
| `ProcessCount` | `int` | Number of processes Fiber will use. |
| `PID` | `int` | Current process identifier. |
| `Prefork` | `bool` | Whether prefork is enabled. |
| `ChildPIDs` | `[]int` | Child process identifiers when preforking. |
| `ColorScheme` | [`Colors`](https://github.com/gofiber/fiber/blob/main/color.go) | Active color scheme for the startup message. |
## OnFork
Runs each time a child process is spawned while prefork is enabled. The callback receives the child's process ID.
```go title="Signature"
func (h *Hooks) OnFork(handler ...OnForkHandler)
```
## OnPreShutdown
Runs before the server shuts down.
```go title="Signature"
func (h *Hooks) OnPreShutdown(handler ...OnPreShutdownHandler)
```
## OnPostShutdown
Runs after the server shuts down.
```go title="Signature"
func (h *Hooks) OnPostShutdown(handler ...OnPostShutdownHandler)
```
## OnMount
Fires after a sub-app is mounted on a parent. The parent app is passed to the callback and it works for both app and group mounts.
```go title="Signature"
func (h *Hooks) OnMount(handler ...OnMountHandler)
```
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/", testSimpleHandler).Name("x")
subApp := fiber.New()
subApp.Get("/test", testSimpleHandler)
subApp.Hooks().OnMount(func(parent *fiber.App) error {
fmt.Print("Mount path of parent app: " + parent.MountPath())
// Additional custom logic...
return nil
})
app.Use("/sub", subApp)
}
func testSimpleHandler(c fiber.Ctx) error {
return c.SendString("Hello, Fiber!")
}
// Result:
// Mount path of parent app: /sub
```
:::caution
OnName, OnRoute, OnGroup, and OnGroupName are mount-sensitive. When you mount a sub-app that registers these hooks, route and group paths include the mount prefix.
:::
================================================
FILE: docs/api/log.md
================================================
---
id: log
title: 📃 Log
description: Fiber's built-in log package
sidebar_position: 6
---
Logs help you observe program behavior, diagnose issues, and trigger alerts. Structured logs improve searchability and speed up troubleshooting.
Fiber logs to standard output by default and exposes global helpers such as `log.Info`, `log.Errorf`, and `log.Warnw`.
## Log Levels
```go
const (
LevelTrace Level = iota
LevelDebug
LevelInfo
LevelWarn
LevelError
LevelFatal
LevelPanic
)
```
## Custom Log
Fiber provides the generic `AllLogger[T]` interface for adapting various log libraries.
```go
type CommonLogger interface {
Logger
FormatLogger
WithLogger
}
type ConfigurableLogger[T any] interface {
// SetLevel sets logging level.
SetLevel(level Level)
// SetOutput sets the logger output.
SetOutput(w io.Writer)
// Logger returns the logger instance.
Logger() T
}
type AllLogger[T any] interface {
CommonLogger
ConfigurableLogger[T]
WithLogger
}
```
## Print Log
**Note:** The Fatal level method will terminate the program after printing the log message. Please use it with caution.
### Basic Logging
Call level-specific methods directly; entries use the `messageKey` (default `msg`).
```go
log.Info("Hello, World!")
log.Debug("Are you OK?")
log.Info("42 is the answer to life, the universe, and everything")
log.Warn("We are under attack!")
log.Error("Houston, we have a problem.")
log.Fatal("So Long, and Thanks for All the Fish.")
log.Panic("The system is down.")
```
### Formatted Logging
Append `f` to format the message.
```go
log.Debugf("Hello %s", "boy")
log.Infof("%d is the answer to life, the universe, and everything", 42)
log.Warnf("We are under attack, %s!", "boss")
log.Errorf("%s, we have a problem.", "John Smith")
log.Fatalf("So Long, and Thanks for All the %s.", "fish")
```
### Key-Value Logging
Key-value helpers log structured fields; mismatched pairs emit `KEYVALS UNPAIRED`.
```go
log.Debugw("", "greeting", "Hello", "target", "boy")
log.Infow("", "number", 42)
log.Warnw("", "job", "boss")
log.Errorw("", "name", "John Smith")
log.Fatalw("", "fruit", "fish")
```
## Global Log
Fiber also exposes a global logger for quick messages.
```go
import "github.com/gofiber/fiber/v3/log"
log.Info("info")
log.Warn("warn")
```
The example uses `log.DefaultLogger`, which writes to stdout. The [contrib](https://github.com/gofiber/contrib) repo offers adapters like `fiberzap` and `fiberzerolog`, or you can register your own with `log.SetLogger`.
Here's an example using a custom logger:
```go
import (
"log"
fiberlog "github.com/gofiber/fiber/v3/log"
)
var _ fiberlog.AllLogger[*log.Logger] = (*customLogger)(nil)
type customLogger struct {
stdlog *log.Logger
}
// Implement required methods for the AllLogger interface...
// Inject your custom logger
fiberlog.SetLogger[*log.Logger](&customLogger{
stdlog: log.New(os.Stdout, "CUSTOM ", log.LstdFlags),
})
// Retrieve the underlying *log.Logger for direct use
std := fiberlog.DefaultLogger[*log.Logger]().Logger()
std.Println("custom logging")
```
## Set Level
`log.SetLevel` sets the minimum level that will be output. The default is `LevelTrace`.
**Note:** This method is not safe for concurrent use; call it during setup, before logging from multiple goroutines.
```go
import "github.com/gofiber/fiber/v3/log"
log.SetLevel(log.LevelInfo)
```
Setting the log level allows you to control the verbosity of the logs, filtering out messages below the specified level.
## Set Output
`log.SetOutput` sets where logs are written. By default, they go to the console.
### Writing Logs to Stderr
```go
// Send log output to standard error
log.SetOutput(os.Stderr)
```
This lets you route logs to a file, service, or any destination.
### Writing Logs to a File
To write to a file such as `test.log`:
```go
// Output to ./test.log file
f, err := os.OpenFile("test.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
log.Fatal("Failed to open log file:", err)
}
log.SetOutput(f)
```
### Writing Logs to Both Console and File
Write to both `test.log` and `stdout`:
```go
// Output to ./test.log file
file, err := os.OpenFile("test.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
log.Fatal("Failed to open log file:", err)
}
iw := io.MultiWriter(os.Stdout, file)
log.SetOutput(iw)
```
## Bind Context
Bind a logger to a context with `log.WithContext`, which returns a `CommonLogger` tied to that context.
```go
commonLogger := log.WithContext(ctx)
commonLogger.Info("info")
```
Context binding adds request-specific data for easier tracing.
## Logger
Use `Logger` to access the underlying logger and call its native methods:
```go
logger := fiberlog.DefaultLogger[*log.Logger]() // Get the default logger instance
stdlogger := logger.Logger() // stdlogger is *log.Logger
stdlogger.SetFlags(0) // Hide timestamp by setting flags to 0
```
================================================
FILE: docs/api/redirect.md
================================================
---
id: redirect
title: 🔄 Redirect
description: Fiber's built-in redirect package
sidebar_position: 5
toc_max_heading_level: 5
---
Redirect helpers send the client to another URL or route.
## Redirect Methods
### To
Redirects to a URL built from the given path. Optionally set an HTTP [status](#status).
:::info
If unspecified, status defaults to **303 See Other**.
:::
```go title="Signature"
func (r *Redirect) To(location string) error
```
```go title="Example"
app.Get("/coffee", func(c fiber.Ctx) error {
// => HTTP - GET 301 /teapot
return c.Redirect().Status(fiber.StatusMovedPermanently).To("/teapot")
})
app.Get("/teapot", func(c fiber.Ctx) error {
return c.Status(fiber.StatusTeapot).Send("🍵 short and stout 🍵")
})
```
```go title="More examples"
app.Get("/", func(c fiber.Ctx) error {
// => HTTP - GET 303 /foo/bar
return c.Redirect().To("/foo/bar")
// => HTTP - GET 303 ../login
return c.Redirect().To("../login")
// => HTTP - GET 303 http://example.com
return c.Redirect().To("http://example.com")
// => HTTP - GET 301 http://example.com
return c.Redirect().Status(301).To("http://example.com")
})
```
### Route
Redirects to a named route with parameters and queries.
:::info
To send params and queries to a route, use the [`RedirectConfig`](#redirectconfig) struct.
:::
```go title="Signature"
func (r *Redirect) Route(name string, config ...RedirectConfig) error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
// /user/fiber
return c.Redirect().Route("user", fiber.RedirectConfig{
Params: fiber.Map{
"name": "fiber",
},
})
})
app.Get("/with-queries", func(c fiber.Ctx) error {
// /user/fiber?data[0][name]=john&data[0][age]=10&test=doe
return c.Redirect().Route("user", fiber.RedirectConfig{
Params: fiber.Map{
"name": "fiber",
},
Queries: map[string]string{
"data[0][name]": "john",
"data[0][age]": "10",
"test": "doe",
},
})
})
app.Get("/user/:name", func(c fiber.Ctx) error {
return c.SendString(c.Params("name"))
}).Name("user")
```
### Back
Redirects to the referer. If it's missing, Fiber falls back to the provided URL. You can also set the status code.
:::info
If unspecified, status defaults to **303 See Other**.
:::
```go title="Signature"
func (r *Redirect) Back(fallback string) error
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Home page")
})
app.Get("/test", func(c fiber.Ctx) error {
c.Set("Content-Type", "text/html")
return c.SendString(`<a href="/back">Back</a>`)
})
app.Get("/back", func(c fiber.Ctx) error {
return c.Redirect().Back("/")
})
```
## Controls
:::info
Methods are **chainable**.
:::
### Status
Sets the HTTP status code for the redirect.
:::info
It is used in conjunction with [**To**](#to), [**Route**](#route), and [**Back**](#back) methods.
:::
```go title="Signature"
func (r *Redirect) Status(status int) *Redirect
```
```go title="Example"
app.Get("/coffee", func(c fiber.Ctx) error {
// => HTTP - GET 301 /teapot
return c.Redirect().Status(fiber.StatusMovedPermanently).To("/teapot")
})
```
### RedirectConfig
Sets the configuration for the redirect.
:::info
It is used in conjunction with the [**Route**](#route) method.
:::
```go title="Definition"
// RedirectConfig is a config to use with Redirect().Route()
type RedirectConfig struct {
Params fiber.Map // Route parameters
Queries map[string]string // Query map
}
```
### Flash Message
Similar to [Laravel](https://laravel.com/docs/11.x/redirects#redirecting-with-flashed-session-data), we can flash a message and retrieve it in the next request.
#### Messages
Retrieve all flash messages. See [With](#with) for details.
```go title="Signature"
func (r *Redirect) Messages() map[string]string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
messages := c.Redirect().Messages()
return c.JSON(messages)
})
```
#### Message
Get a flash message by key; see [With](#with).
```go title="Signature"
func (r *Redirect) Message(key string) string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
message := c.Redirect().Message("status")
return c.SendString(message)
})
```
#### OldInputs
Retrieve stored input data. See [WithInput](#withinput).
```go title="Signature"
func (r *Redirect) OldInputs() map[string]string
```
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
oldInputs := c.Redirect().OldInputs()
return c.JSON(oldInputs)
})
```
#### OldInput
Get stored input data by key; see [WithInput](#withinput).
```go title="Signature"
func (r *Redirect) OldInput(key string) string
```
```go title="Example"
app.Get("/name", func(c fiber.Ctx) error {
oldInput := c.Redirect().OldInput("name")
return c.SendString(oldInput)
})
```
#### With
Send flash messages with `With`.
```go title="Signature"
func (r *Redirect) With(key, value string) *Redirect
```
```go title="Example"
app.Get("/login", func(c fiber.Ctx) error {
return c.Redirect().With("status", "Logged in successfully").To("/")
})
app.Get("/", func(c fiber.Ctx) error {
// => Logged in successfully
return c.SendString(c.Redirect().Message("status"))
})
```
#### WithInput
Send input data with `WithInput`, which stores it in a cookie.
It captures form, multipart, or query data depending on the request content type.
```go title="Signature"
func (r *Redirect) WithInput() *Redirect
```
```go title="Example"
// curl -X POST http://localhost:3000/login -d "name=John"
app.Post("/login", func(c fiber.Ctx) error {
return c.Redirect().WithInput().Route("name")
})
app.Get("/name", func(c fiber.Ctx) error {
// => John
return c.SendString(c.Redirect().OldInput("name"))
}).Name("name")
```
================================================
FILE: docs/api/services.md
================================================
---
id: services
title: 🧩 Services
sidebar_position: 9
---
Services wrap external dependencies. Register them in the application's configuration, and Fiber starts and stops them automatically and makes them available through the application's state—useful during development and testing.
After adding a service to the app configuration, Fiber starts it on launch and stops it during shutdown. Retrieve a service from state with `GetService` or `MustGetService` (see [State Management](./state)).
## Service Interface
The `Service` interface defines methods a service must implement.
### Definition
```go
type Service interface {
// Start starts the service, returning an error if it fails.
Start(ctx context.Context) error
// String returns a string representation of the service.
// It is used to print a human-readable name of the service in the startup message.
String() string
// State returns the current state of the service.
State(ctx context.Context) (string, error)
// Terminate terminates the service, returning an error if it fails.
Terminate(ctx context.Context) error
}
```
## Service Methods
### Start
Starts the service. Fiber calls this when the application starts.
```go
func (s *SomeService) Start(ctx context.Context) error
```
### String
Returns a string representation of the service, used to print the service in the startup message.
```go
func (s *SomeService) String() string
```
### State
Reports the current state of the service for the startup message.
```go
func (s *SomeService) State(ctx context.Context) (string, error)
```
### Terminate
Stops the service. Fiber calls this from a post-shutdown hook, after the application has shut down.
```go
func (s *SomeService) Terminate(ctx context.Context) error
```
## Comprehensive Examples
### Example: Adding a Service
This example demonstrates how to add a Redis store as a service to the application, backed by the Testcontainers Redis Go module.
```go
package main
import (
"context"
"fmt"
"log"
"time"
"github.com/gofiber/fiber/v3"
"github.com/redis/go-redis/v9"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const redisServiceName = "redis-store"
type redisService struct {
ctr *tcredis.RedisContainer
}
// Start initializes and starts the service. It implements the [fiber.Service] interface.
func (s *redisService) Start(ctx context.Context) error {
// start the service
c, err := tcredis.Run(ctx, "redis:latest")
if err != nil {
return err
}
s.ctr = c
return nil
}
// String returns a string representation of the service.
// It is used to print a human-readable name of the service in the startup message.
// It implements the [fiber.Service] interface.
func (s *redisService) String() string {
return redisServiceName
}
// State returns the current state of the service.
// It implements the [fiber.Service] interface.
func (s *redisService) State(ctx context.Context) (string, error) {
state, err := s.ctr.State(ctx)
if err != nil {
return "", fmt.Errorf("container state: %w", err)
}
return state.Status, nil
}
// Terminate stops and removes the service. It implements the [fiber.Service] interface.
func (s *redisService) Terminate(ctx context.Context) error {
// stop the service
return s.ctr.Terminate(ctx)
}
func main() {
cfg := &fiber.Config{}
// Initialize service.
cfg.Services = append(cfg.Services, &redisService{})
// Define a context provider for the services startup.
// This is useful to cancel the startup of the services if the context is canceled.
// Default is context.Background().
startupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cfg.ServicesStartupContextProvider = func() context.Context {
return startupCtx
}
// Define a context provider for the services shutdown.
// This is useful to cancel the shutdown of the services if the context is canceled.
// Default is context.Background().
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cfg.ServicesShutdownContextProvider = func() context.Context {
return shutdownCtx
}
app := fiber.New(*cfg)
ctx := context.Background()
// Obtain the Redis service from the application's State.
redisSrv, ok := fiber.GetService[*redisService](app.State(), redisServiceName)
if !ok || redisSrv == nil {
log.Printf("Redis service not found")
return
}
// Obtain the connection string from the service.
connString, err := redisSrv.ctr.ConnectionString(ctx)
if err != nil {
log.Printf("Could not get connection string: %v", err)
return
}
// Parse the connection string to create a Redis client.
options, err := redis.ParseURL(connString)
if err != nil {
log.Printf("failed to parse connection string: %s", err)
return
}
// Initialize the Redis client.
rdb := redis.NewClient(options)
// Check the Redis connection.
if err := rdb.Ping(ctx).Err(); err != nil {
log.Fatalf("Could not connect to Redis: %v", err)
}
app.Listen(":3000")
}
```
### Example: Add a service with the Store middleware
This example shows how to use services with the Store middleware for dependency injection. It uses a Redis store backed by the Testcontainers Redis module.
```go
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/logger"
redisStore "github.com/gofiber/storage/redis/v3"
"github.com/redis/go-redis/v9"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const (
redisServiceName = "redis-store"
)
type User struct {
ID int `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
}
type redisService struct {
ctr *tcredis.RedisContainer
}
// Start initializes and starts the service. It implements the [fiber.Service] interface.
func (s *redisService) Start(ctx context.Context) error {
// start the service
c, err := tcredis.Run(ctx, "redis:latest")
if err != nil {
return err
}
s.ctr = c
return nil
}
// String returns a string representation of the service.
// It is used to print a human-readable name of the service in the startup message.
// It implements the [fiber.Service] interface.
func (s *redisService) String() string {
return redisServiceName
}
// State returns the current state of the service.
// It implements the [fiber.Service] interface.
func (s *redisService) State(ctx context.Context) (string, error) {
state, err := s.ctr.State(ctx)
if err != nil {
return "", fmt.Errorf("container state: %w", err)
}
return state.Status, nil
}
// Terminate stops and removes the service. It implements the [fiber.Service] interface.
func (s *redisService) Terminate(ctx context.Context) error {
// stop the service
return s.ctr.Terminate(ctx)
}
func main() {
cfg := &fiber.Config{}
// Initialize service.
cfg.Services = append(cfg.Services, &redisService{})
// Define a context provider for the services startup.
// This is useful to cancel the startup of the services if the context is canceled.
// Default is context.Background().
startupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cfg.ServicesStartupContextProvider = func() context.Context {
return startupCtx
}
// Define a context provider for the services shutdown.
// This is useful to cancel the shutdown of the services if the context is canceled.
// Default is context.Background().
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cfg.ServicesShutdownContextProvider = func() context.Context {
return shutdownCtx
}
app := fiber.New(*cfg)
// Initialize default config
app.Use(logger.New())
ctx := context.Background()
// Obtain the Redis service from the application's State.
redisSrv, ok := fiber.GetService[*redisService](app.State(), redisServiceName)
if !ok || redisSrv == nil {
log.Printf("Redis service not found")
return
}
// Obtain the connection string from the service.
connString, err := redisSrv.ctr.ConnectionString(ctx)
if err != nil {
log.Printf("Could not get connection string: %v", err)
return
}
// define a GoFiber session store, backed by the Redis service
store := redisStore.New(redisStore.Config{
URL: connString,
})
app.Post("/user/create", func(c fiber.Ctx) error {
var user User
if err := c.Bind().JSON(&user); err != nil {
return c.Status(fiber.StatusBadRequest).SendString(err.Error())
}
json, err := json.Marshal(user)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
// Save the user to the database.
err = store.Set(user.Email, json, time.Hour*24)
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.JSON(user)
})
app.Get("/user/:id", func(c fiber.Ctx) error {
id := c.Params("id")
user, err := store.Get(id)
if err == redis.Nil {
return c.Status(fiber.StatusNotFound).SendString("User not found")
} else if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.JSON(string(user))
})
app.Listen(":3000")
}
```
================================================
FILE: docs/api/state.md
================================================
---
id: state
title: 🗂️ State Management
sidebar_position: 8
---
State management provides a global key–value store for application dependencies and runtime data. The store is shared across the entire application and persists between requests. It's commonly used to store [Services](../api/services), which you can retrieve with the `GetService` or `MustGetService` functions.
:::warning
When prefork is enabled, each worker process has an independent state store, meaning state is not shared between them.
:::
## State Type
`State` is a key–value store built on top of `sync.Map` to ensure safe concurrent access. It allows storage and retrieval of dependencies and configurations in a Fiber application as well as thread-safe access to runtime data.
### Definition
```go
// State is a key–value store for Fiber's app, used as a global storage for the app's dependencies.
// It is a thread–safe implementation of a map[string]any, using sync.Map.
type State struct {
dependencies sync.Map
}
```
## Methods on State
### Set
Set adds or updates a key–value pair in the State.
```go
// Set adds or updates a key–value pair in the State.
func (s *State) Set(key string, value any)
```
**Usage Example:**
```go
app.State().Set("appName", "My Fiber App")
```
### Get
Get retrieves a value from the State.
```go title="Signature"
func (s *State) Get(key string) (any, bool)
```
**Usage Example:**
```go
value, ok := app.State().Get("appName")
if ok {
fmt.Println("App Name:", value)
}
```
### MustGet
MustGet retrieves a value from the State and panics if the key is not found.
```go title="Signature"
func (s *State) MustGet(key string) any
```
**Usage Example:**
```go
appName := app.State().MustGet("appName")
fmt.Println("App Name:", appName)
```
### Has
Has checks if a key exists in the State.
```go title="Signature"
func (s *State) Has(key string) bool
```
**Usage Example:**
```go
if app.State().Has("appName") {
fmt.Println("App Name is set.")
}
```
### Delete
Delete removes a key–value pair from the State.
```go title="Signature"
func (s *State) Delete(key string)
```
**Usage Example:**
```go
app.State().Delete("obsoleteKey")
```
### Reset
Reset removes all keys from the State, including those related to Services.
```go title="Signature"
func (s *State) Reset()
```
**Usage Example:**
```go
app.State().Reset()
```
### Keys
Keys returns a slice containing all keys present in the State.
```go title="Signature"
func (s *State) Keys() []string
```
**Usage Example:**
```go
keys := app.State().Keys()
fmt.Println("State Keys:", keys)
```
### Len
Len returns the number of keys in the State.
```go
// Len returns the number of keys in the State.
func (s *State) Len() int
```
**Usage Example:**
```go
fmt.Printf("Total State Entries: %d\n", app.State().Len())
```
### GetString
GetString retrieves a string value from the State. It returns the string and a boolean indicating a successful type assertion.
```go title="Signature"
func (s *State) GetString(key string) (string, bool)
```
**Usage Example:**
```go
if appName, ok := app.State().GetString("appName"); ok {
fmt.Println("App Name:", appName)
}
```
### GetInt
GetInt retrieves an integer value from the State. It returns the int and a boolean indicating a successful type assertion.
```go title="Signature"
func (s *State) GetInt(key string) (int, bool)
```
**Usage Example:**
```go
if count, ok := app.State().GetInt("userCount"); ok {
fmt.Printf("User Count: %d\n", count)
}
```
### GetBool
GetBool retrieves a boolean value from the State. It returns the bool and a boolean indicating a successful type assertion.
```go title="Signature"
func (s *State) GetBool(key string) (bool, bool)
```
**Usage Example:**
```go
if debug, ok := app.State().GetBool("debugMode"); ok {
fmt.Printf("Debug Mode: %v\n", debug)
}
```
### GetFloat64
GetFloat64 retrieves a float64 value from the State. It returns the float64 and a boolean indicating a successful type assertion.
```go title="Signature"
func (s *State) GetFloat64(key string) (float64, bool)
```
**Usage Example:**
```go
if ratio, ok := app.State().GetFloat64("scalingFactor"); ok {
fmt.Printf("Scaling Factor: %f\n", ratio)
}
```
### GetUint
GetUint retrieves a `uint` value from the State.
```go title="Signature"
func (s *State) GetUint(key string) (uint, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint("maxConnections"); ok {
fmt.Printf("Max Connections: %d\n", val)
}
```
### GetInt8
GetInt8 retrieves an `int8` value from the State.
```go title="Signature"
func (s *State) GetInt8(key string) (int8, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetInt8("threshold"); ok {
fmt.Printf("Threshold: %d\n", val)
}
```
### GetInt16
GetInt16 retrieves an `int16` value from the State.
```go title="Signature"
func (s *State) GetInt16(key string) (int16, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetInt16("minValue"); ok {
fmt.Printf("Minimum Value: %d\n", val)
}
```
### GetInt32
GetInt32 retrieves an `int32` value from the State.
```go title="Signature"
func (s *State) GetInt32(key string) (int32, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetInt32("portNumber"); ok {
fmt.Printf("Port Number: %d\n", val)
}
```
### GetInt64
GetInt64 retrieves an `int64` value from the State.
```go title="Signature"
func (s *State) GetInt64(key string) (int64, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetInt64("fileSize"); ok {
fmt.Printf("File Size: %d\n", val)
}
```
### GetUint8
GetUint8 retrieves a `uint8` value from the State.
```go title="Signature"
func (s *State) GetUint8(key string) (uint8, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint8("byteValue"); ok {
fmt.Printf("Byte Value: %d\n", val)
}
```
### GetUint16
GetUint16 retrieves a `uint16` value from the State.
```go title="Signature"
func (s *State) GetUint16(key string) (uint16, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint16("limit"); ok {
fmt.Printf("Limit: %d\n", val)
}
```
### GetUint32
GetUint32 retrieves a `uint32` value from the State.
```go title="Signature"
func (s *State) GetUint32(key string) (uint32, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint32("timeout"); ok {
fmt.Printf("Timeout: %d\n", val)
}
```
### GetUint64
GetUint64 retrieves a `uint64` value from the State.
```go title="Signature"
func (s *State) GetUint64(key string) (uint64, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUint64("maxSize"); ok {
fmt.Printf("Max Size: %d\n", val)
}
```
### GetUintptr
GetUintptr retrieves a `uintptr` value from the State.
```go title="Signature"
func (s *State) GetUintptr(key string) (uintptr, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetUintptr("pointerValue"); ok {
fmt.Printf("Pointer Value: %d\n", val)
}
```
### GetFloat32
GetFloat32 retrieves a `float32` value from the State.
```go title="Signature"
func (s *State) GetFloat32(key string) (float32, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetFloat32("scalingFactor32"); ok {
fmt.Printf("Scaling Factor (float32): %f\n", val)
}
```
### GetComplex64
GetComplex64 retrieves a `complex64` value from the State.
```go title="Signature"
func (s *State) GetComplex64(key string) (complex64, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetComplex64("complexVal"); ok {
fmt.Printf("Complex Value (complex64): %v\n", val)
}
```
### GetComplex128
GetComplex128 retrieves a `complex128` value from the State.
```go title="Signature"
func (s *State) GetComplex128(key string) (complex128, bool)
```
**Usage Example:**
```go
if val, ok := app.State().GetComplex128("complexVal128"); ok {
fmt.Printf("Complex Value (complex128): %v\n", val)
}
```
## Generic Functions
Fiber provides generic functions to retrieve state values with type safety and fallback options.
### GetState
GetState retrieves a value from the State and casts it to the desired type. It returns the cast value and a boolean indicating if the cast was successful.
```go title="Signature"
func GetState[T any](s *State, key string) (T, bool)
```
**Usage Example:**
```go
// Retrieve an integer value safely.
userCount, ok := fiber.GetState[int](app.State(), "userCount")
if ok {
fmt.Printf("User Count: %d\n", userCount)
}
```
### MustGetState
MustGetState retrieves a value from the State and casts it to the desired type. It panics if the key is not found or if the type assertion fails.
```go title="Signature"
func MustGetState[T any](s *State, key string) T
```
**Usage Example:**
```go
// Retrieve the value or panic if it is not present.
config := fiber.MustGetState[string](app.State(), "configFile")
fmt.Println("Config File:", config)
```
### GetStateWithDefault
GetStateWithDefault retrieves a value from the State, casting it to the desired type. If the key is not present, it returns the provided default value.
```go title="Signature"
func GetStateWithDefault[T any](s *State, key string, defaultVal T) T
```
**Usage Example:**
```go
// Retrieve a value with a fallback.
requestCount := fiber.GetStateWithDefault[int](app.State(), "requestCount", 0)
fmt.Printf("Request Count: %d\n", requestCount)
```
### GetService
GetService retrieves a Service from the State and casts it to the desired type. It returns the cast value and a boolean indicating if the cast was successful.
```go title="Signature"
func GetService[T Service](s *State, key string) (T, bool)
```
**Usage Example:**
```go
if srv, ok := fiber.GetService[*redisService](app.State(), "someService"); ok {
fmt.Printf("Some Service: %s\n", srv.String())
}
```
### MustGetService
MustGetService retrieves a Service from the State and casts it to the desired type. It panics if the key is not found or if the type assertion fails.
```go title="Signature"
func MustGetService[T Service](s *State, key string) T
```
**Usage Example:**
```go
srv := fiber.MustGetService[*SomeService](app.State(), "someService")
```
## Comprehensive Examples
### Example: Request Counter
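This example demonstrates how to track the number of requests using the State. Note that the middleware reads the counter and then writes it back in two separate calls, so concurrent requests can overwrite each other's increments; for an exact count, store a shared counter such as an `*atomic.Int64` in the State and increment it instead.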
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// Initialize state with a counter.
app.State().Set("requestCount", 0)
// Middleware: Increase counter for every request.
app.Use(func(c fiber.Ctx) error {
count, _ := c.App().State().GetInt("requestCount")
app.State().Set("requestCount", count+1)
return c.Next()
})
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello World!")
})
app.Get("/stats", func(c fiber.Ctx) error {
count, _ := c.App().State().Get("requestCount")
return c.SendString(fmt.Sprintf("Total requests: %d", count))
})
app.Listen(":3000")
}
```
### Example: Environment–Specific Configuration
This example shows how to configure different settings based on the environment.
```go
package main
import (
"os"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// Determine environment.
environment := os.Getenv("ENV")
if environment == "" {
environment = "development"
}
app.State().Set("environment", environment)
// Set environment-specific configurations.
if environment == "development" {
app.State().Set("apiUrl", "http://localhost:8080/api")
app.State().Set("debug", true)
} else {
app.State().Set("apiUrl", "https://api.production.com")
app.State().Set("debug", false)
}
app.Get("/config", func(c fiber.Ctx) error {
config := map[string]any{
"environment": environment,
"apiUrl": fiber.GetStateWithDefault(c.App().State(), "apiUrl", ""),
"debug": fiber.GetStateWithDefault(c.App().State(), "debug", false),
}
return c.JSON(config)
})
app.Listen(":3000")
}
```
### Example: Dependency Injection with State Management
This example demonstrates how to use the State for dependency injection in a Fiber application.
```go
package main
import (
"context"
"fmt"
"log"
"github.com/gofiber/fiber/v3"
"github.com/redis/go-redis/v9"
)
type User struct {
ID int `query:"id"`
Name string `query:"name"`
Email string `query:"email"`
}
func main() {
app := fiber.New()
ctx := context.Background()
// Initialize Redis client.
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "",
DB: 0,
})
// Check the Redis connection.
if err := rdb.Ping(ctx).Err(); err != nil {
log.Fatalf("Could not connect to Redis: %v", err)
}
// Inject the Redis client into Fiber's State for dependency injection.
app.State().Set("redis", rdb)
app.Get("/user/create", func(c fiber.Ctx) error {
var user User
if err := c.Bind().Query(&user); err != nil {
return c.Status(fiber.StatusBadRequest).SendString(err.Error())
}
// Retrieve the Redis client from the global state.
rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis")
if !ok {
return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found")
}
// Save the user to the database.
key := fmt.Sprintf("user:%d", user.ID)
err := rdb.HSet(ctx, key, "name", user.Name, "email", user.Email).Err()
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
return c.JSON(user)
})
app.Get("/user/:id", func(c fiber.Ctx) error {
id := c.Params("id")
rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis")
if !ok {
return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found")
}
key := fmt.Sprintf("user:%s", id)
user, err := rdb.HGetAll(ctx, key).Result()
if err != nil {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
}
// HGetAll returns an empty map (not redis.Nil) when the key does not exist.
if len(user) == 0 {
return c.Status(fiber.StatusNotFound).SendString("User not found")
}
return c.JSON(user)
})
app.Listen(":3000")
}
```
================================================
FILE: docs/client/_category_.json
================================================
{
"label": "\uD83C\uDF0E Client",
"position": 6,
"link": {
"type": "generated-index",
"description": "HTTP client for Fiber."
}
}
================================================
FILE: docs/client/examples.md
================================================
---
id: examples
title: 🍳 Examples
description: >-
Client usage examples.
sidebar_position: 5
---
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
## Basic Auth
Clients send credentials via the `Authorization` header, while the server
stores hashed passwords as shown in the middleware example.
```go
package main
import (
"encoding/base64"
"fmt"
"github.com/gofiber/fiber/v3/client"
)
func main() {
cc := client.New()
out := base64.StdEncoding.EncodeToString([]byte("john:doe"))
resp, err := cc.Get("http://localhost:3000", client.Config{
Header: map[string]string{
"Authorization": "Basic " + out,
},
})
if err != nil {
panic(err)
}
fmt.Print(string(resp.Body()))
}
```
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/basicauth"
)
func main() {
app := fiber.New()
app.Use(
basicauth.New(basicauth.Config{
Users: map[string]string{
// "doe" hashed using SHA-256
"john": "{SHA256}eZ75KhGvkY4/t0HfQpNPO1aO0tk6wd908bjUGieTKm8=",
},
}),
)
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Listen(":3000")
}
```
## TLS
```go
package main
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"github.com/gofiber/fiber/v3/client"
)
func main() {
cc := client.New()
certPool, err := x509.SystemCertPool()
if err != nil {
panic(err)
}
cert, err := os.ReadFile("ssl.cert")
if err != nil {
panic(err)
}
certPool.AppendCertsFromPEM(cert)
cc.SetTLSConfig(&tls.Config{
RootCAs: certPool,
})
resp, err := cc.Get("https://localhost:3000")
if err != nil {
panic(err)
}
fmt.Print(string(resp.Body()))
}
```
```go
package main
import (
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
err := app.Listen(":3000", fiber.ListenConfig{
CertFile: "ssl.cert",
CertKeyFile: "ssl.key",
})
if err != nil {
panic(err)
}
}
```
## Reusing fasthttp transports
The Fiber client can wrap existing `fasthttp` clients so that you can reuse
connection pools, custom dialers, or load-balancing logic that is already tuned
for your infrastructure.
### HostClient
```go
package main
import (
"log"
"time"
"github.com/gofiber/fiber/v3/client"
"github.com/valyala/fasthttp"
)
func main() {
hc := &fasthttp.HostClient{
Addr: "api.internal:443",
IsTLS: true,
MaxConnDuration: 30 * time.Second,
MaxIdleConnDuration: 10 * time.Second,
}
cc := client.NewWithHostClient(hc)
resp, err := cc.Get("https://api.internal:443/status")
if err != nil {
log.Fatal(err)
}
log.Printf("status=%d body=%s", resp.StatusCode(), resp.Body())
}
```
### LBClient
```go
package main
import (
"log"
"time"
"github.com/gofiber/fiber/v3/client"
"github.com/valyala/fasthttp"
)
func main() {
lb := &fasthttp.LBClient{
Timeout: 2 * time.Second,
Clients: []fasthttp.BalancingClient{
&fasthttp.HostClient{Addr: "edge-1.internal:8080"},
&fasthttp.HostClient{Addr: "edge-2.internal:8080"},
},
}
cc := client.NewWithLBClient(lb)
// Per-request overrides such as redirects, retries, TLS, and proxy dialers
// are shared across every host client managed by the load balancer.
resp, err := cc.Get("http://service.internal/api")
if err != nil {
log.Fatal(err)
}
log.Printf("status=%d body=%s", resp.StatusCode(), resp.Body())
}
```
## Cookie jar
The client can store and reuse cookies between requests by attaching a cookie jar.
### Request
```go
func main() {
jar := client.AcquireCookieJar()
defer client.ReleaseCookieJar(jar)
cc := client.New()
cc.SetCookieJar(jar)
jar.SetKeyValueBytes("httpbin.org", []byte("john"), []byte("doe"))
resp, err := cc.Get("https://httpbin.org/cookies")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
}
```
Click here to see the result
```json
{
"cookies": {
"john": "doe"
}
}
```
### Response
Read cookies set by the server directly from the jar.
```go
func main() {
jar := client.AcquireCookieJar()
defer client.ReleaseCookieJar(jar)
cc := client.New()
cc.SetCookieJar(jar)
_, err := cc.Get("https://httpbin.org/cookies/set/john/doe")
if err != nil {
panic(err)
}
uri := fasthttp.AcquireURI()
defer fasthttp.ReleaseURI(uri)
uri.SetHost("httpbin.org")
uri.SetPath("/cookies")
fmt.Println(jar.Get(uri))
}
```
Click here to see the result
```plaintext
[john=doe; path=/]
```
### Response (follow-up request)
```go
func main() {
jar := client.AcquireCookieJar()
defer client.ReleaseCookieJar(jar)
cc := client.New()
cc.SetCookieJar(jar)
_, err := cc.Get("https://httpbin.org/cookies/set/john/doe")
if err != nil {
panic(err)
}
resp, err := cc.Get("https://httpbin.org/cookies")
if err != nil {
panic(err)
}
fmt.Println(resp.String())
}
```
Click here to see the result
```json
{
"cookies": {
"john": "doe"
}
}
```
================================================
FILE: docs/client/hooks.md
================================================
---
id: hooks
title: 🎣 Hooks
description: >-
Hooks are used to manipulate the request/response process of the Fiber client.
sidebar_position: 4
---
Hooks let you intercept and modify the request or response flow of the Fiber client. They are useful for:
- Changing request parameters (e.g., URL, headers) before sending the request.
- Logging request and response details.
- Integrating complex tracing or monitoring tools.
- Handling authentication, retries, or other custom logic.
There are two kinds of hooks:
## Request Hooks
**Request hooks** are functions executed before the HTTP request is sent. They follow the signature:
```go
type RequestHook func(*Client, *Request) error
```
A request hook receives both the `Client` and the `Request` objects, allowing you to modify the request before it leaves your application. For example, you could:
- Change the host URL.
- Log request details (method, URL, headers).
- Add or modify headers or query parameters.
- Intercept and apply custom authentication logic.
**Example:**
```go
type Repository struct {
Name string `json:"name"`
FullName string `json:"full_name"`
Description string `json:"description"`
Homepage string `json:"homepage"`
Owner struct {
Login string `json:"login"`
} `json:"owner"`
}
func main() {
cc := client.New()
// Add a request hook that modifies the request URL before sending.
cc.AddRequestHook(func(c *client.Client, r *client.Request) error {
r.SetURL("https://api.github.com/" + r.URL())
return nil
})
resp, err := cc.Get("repos/gofiber/fiber")
if err != nil {
panic(err)
}
var repo Repository
if err := resp.JSON(&repo); err != nil {
panic(err)
}
fmt.Printf("Status code: %d\n", resp.StatusCode())
fmt.Printf("Repository: %s\n", repo.FullName)
fmt.Printf("Description: %s\n", repo.Description)
fmt.Printf("Homepage: %s\n", repo.Homepage)
fmt.Printf("Owner: %s\n", repo.Owner.Login)
fmt.Printf("Name: %s\n", repo.Name)
fmt.Printf("Full Name: %s\n", repo.FullName)
}
```
Click here to see the result
```plaintext
Status code: 200
Repository: gofiber/fiber
Description: ⚡️ Express inspired web framework written in Go
Homepage: https://gofiber.io
Owner: gofiber
Name: fiber
Full Name: gofiber/fiber
```
### Built-in Request Hooks
Fiber includes built-in request hooks:
- **parserRequestURL**: Normalizes and customizes the URL based on path and query parameters. Required for `PathParam` and `QueryParam` methods.
- **parserRequestHeader**: Sets request headers, cookies, content type, referer, and user agent based on client and request properties.
- **parserRequestBody**: Automatically serializes the request body (JSON, XML, form, file uploads, etc.).
:::info
If a request hook returns an error, Fiber stops the request and returns the error immediately.
:::
**Example with Multiple Hooks:**
```go
func main() {
cc := client.New()
cc.AddRequestHook(func(c *client.Client, r *client.Request) error {
fmt.Println("Hook 1")
return errors.New("error")
})
cc.AddRequestHook(func(c *client.Client, r *client.Request) error {
fmt.Println("Hook 2")
return nil
})
_, err := cc.Get("https://example.com/")
if err != nil {
panic(err)
}
}
```
Click here to see the result
```shell
Hook 1
panic: error
goroutine 1 [running]:
main.main()
main.go:25 +0xaa
exit status 2
```
## Response Hooks
**Response hooks** are functions executed after the HTTP response is received. They follow the signature:
```go
type ResponseHook func(*Client, *Response, *Request) error
```
A response hook receives the `Client`, `Response`, and `Request` objects, allowing you to inspect and modify the response or perform additional actions such as logging, tracing, or processing response data.
**Example:**
```go
func main() {
cc := client.New()
cc.AddResponseHook(func(c *client.Client, resp *client.Response, req *client.Request) error {
fmt.Printf("Response Status Code: %d\n", resp.StatusCode())
fmt.Printf("HTTP protocol: %s\n\n", resp.Protocol())
fmt.Println("Response Headers:")
for key, value := range resp.RawResponse.Header.All() {
fmt.Printf("%s: %s\n", key, value)
}
return nil
})
_, err := cc.Get("https://example.com/")
if err != nil {
panic(err)
}
}
```
Click here to see the result
```plaintext
Response Status Code: 200
HTTP protocol: HTTP/1.1
Response Headers:
Content-Length: 1256
Content-Type: text/html; charset=UTF-8
Server: ECAcc (dcd/7D5A)
Age: 216114
Cache-Control: max-age=604800
Date: Fri, 10 May 2024 10:49:10 GMT
Etag: "3147526947+gzip+ident"
Expires: Fri, 17 May 2024 10:49:10 GMT
Last-Modified: Thu, 17 Oct 2019 07:18:26 GMT
Vary: Accept-Encoding
X-Cache: HIT
```
### Built-in Response Hooks
Fiber includes built-in response hooks:
- **parserResponseCookie**: Parses cookies from the response and stores them in the response object and cookie jar if available.
- **logger**: Logs information about the raw request and response. It uses the `log.CommonLogger` interface.
:::info
If a response hook returns an error, Fiber skips the remaining hooks and returns that error.
:::
**Example with Multiple Response Hooks:**
```go
func main() {
cc := client.New()
cc.AddResponseHook(func(c *client.Client, r1 *client.Response, r2 *client.Request) error {
fmt.Println("Hook 1")
return nil
})
cc.AddResponseHook(func(c *client.Client, r1 *client.Response, r2 *client.Request) error {
fmt.Println("Hook 2")
return errors.New("error")
})
cc.AddResponseHook(func(c *client.Client, r1 *client.Response, r2 *client.Request) error {
fmt.Println("Hook 3")
return nil
})
_, err := cc.Get("https://example.com/")
if err != nil {
panic(err)
}
}
```
Click here to see the result
```shell
Hook 1
Hook 2
panic: error
goroutine 1 [running]:
main.main()
main.go:30 +0xd6
exit status 2
```
## Hook Execution Order
Hooks run in FIFO order (first in, first out), so they're executed in the order you add them. Keep this in mind when adding multiple hooks, as the order can affect the outcome.
**Example:**
```go
func main() {
cc := client.New()
cc.AddRequestHook(func(c *client.Client, r *client.Request) error {
fmt.Println("Hook 1")
return nil
})
cc.AddRequestHook(func(c *client.Client, r *client.Request) error {
fmt.Println("Hook 2")
return nil
})
_, err := cc.Get("https://example.com/")
if err != nil {
panic(err)
}
}
```
Click here to see the result
```plaintext
Hook 1
Hook 2
```
================================================
FILE: docs/client/request.md
================================================
---
id: request
title: 📤 Request
description: >-
Request methods of Gofiber HTTP client.
sidebar_position: 2
---
The `Request` struct in Fiber's HTTP client represents an HTTP request. It encapsulates the data required to send a request, including:
- **URL**: The endpoint to which the request is sent.
- **Method**: The HTTP method (GET, POST, PUT, DELETE, etc.).
- **Headers**: Key-value pairs that provide additional information about the request or guide how the response should be processed.
- **Body**: The data sent with the request, commonly used with methods like POST and PUT.
- **Query Parameters**: Parameters appended to the URL to pass additional data or modify the request's behavior.
This structure is designed to be both flexible and efficient, allowing you to easily build and modify HTTP requests as needed.
```go
type Request struct {
url string
method string
userAgent string
boundary string
referer string
ctx context.Context
header *Header
params *QueryParam
cookies *Cookie
path *PathParam
timeout time.Duration
maxRedirects int
client *Client
body any
formData *FormData
files []*File
bodyType bodyType
RawRequest *fasthttp.Request
}
```
## REST Methods
### Get
**Get** sends a GET request to the specified URL. It sets the URL and HTTP method, then dispatches the request to the server.
```go title="Signature"
func (r *Request) Get(url string) (*Response, error)
```
### Post
**Post** sends a POST request. It sets the URL and method to POST, then sends the request.
```go title="Signature"
func (r *Request) Post(url string) (*Response, error)
```
### Put
**Put** sends a PUT request. It sets the URL and method to PUT, then sends the request.
```go title="Signature"
func (r *Request) Put(url string) (*Response, error)
```
### Patch
**Patch** sends a PATCH request. It sets the URL and method to PATCH, then sends the request.
```go title="Signature"
func (r *Request) Patch(url string) (*Response, error)
```
### Delete
**Delete** sends a DELETE request. It sets the URL and method to DELETE, then sends the request.
```go title="Signature"
func (r *Request) Delete(url string) (*Response, error)
```
### Head
**Head** sends a HEAD request. It sets the URL and method to HEAD, then sends the request.
```go title="Signature"
func (r *Request) Head(url string) (*Response, error)
```
### Options
**Options** sends an OPTIONS request. It sets the URL and method to OPTIONS, then sends the request.
```go title="Signature"
func (r *Request) Options(url string) (*Response, error)
```
### Custom
**Custom** sends a request using a custom HTTP method. For example, you can use this to send a TRACE or CONNECT request.
```go title="Signature"
func (r *Request) Custom(url, method string) (*Response, error)
```
## AcquireRequest
**AcquireRequest** returns a new pooled `Request`. Call `ReleaseRequest` when you're finished to return it to the pool and limit allocations.
```go title="Signature"
func AcquireRequest() *Request
```
## ReleaseRequest
**ReleaseRequest** returns the `Request` to the pool. Do not use it after releasing; doing so may cause data races.
```go title="Signature"
func ReleaseRequest(req *Request)
```
## Method
**Method** returns the current HTTP method set for the request.
```go title="Signature"
func (r *Request) Method() string
```
## SetMethod
**SetMethod** sets the HTTP method for the `Request` object. Typically, you should use the specialized request methods (`Get`, `Post`, etc.) instead of calling `SetMethod` directly.
```go title="Signature"
func (r *Request) SetMethod(method string) *Request
```
## URL
**URL** returns the current URL set in the `Request`.
```go title="Signature"
func (r *Request) URL() string
```
## SetURL
**SetURL** sets the URL for the `Request` object.
```go title="Signature"
func (r *Request) SetURL(url string) *Request
```
## Client
**Client** retrieves the `Client` instance associated with the `Request`.
```go title="Signature"
func (r *Request) Client() *Client
```
## SetClient
**SetClient** assigns a `Client` to the `Request`. If the provided client is `nil`, it will panic.
```go title="Signature"
func (r *Request) SetClient(c *Client) *Request
```
## Context
**Context** returns the `context.Context` of the request, or `context.Background()` if none is set.
```go title="Signature"
func (r *Request) Context() context.Context
```
## SetContext
**SetContext** sets the `context.Context` for the request, allowing you to cancel or time out the request. See the [Go blog](https://go.dev/blog/context) and [context](https://pkg.go.dev/context) docs for more details.
```go title="Signature"
func (r *Request) SetContext(ctx context.Context) *Request
```
## Header
**Header** returns all values for the specified header key. It searches all header fields stored in the request.
```go title="Signature"
func (r *Request) Header(key string) []string
```
### Headers
**Headers** returns an iterator over all headers in the request. Use `maps.Collect()` to transform them into a map if needed. The returned values are valid only until the request is released. Make copies as required.
```go title="Signature"
func (r *Request) Headers() iter.Seq2[string, []string]
```
Example
```go title="Example"
req := client.AcquireRequest()
req.AddHeader("Golang", "Fiber")
req.AddHeader("Test", "123456")
req.AddHeader("Test", "654321")
for k, v := range req.Headers() {
fmt.Printf("Header Key: %s, Header Value: %v\n", k, v)
}
```
```sh
Header Key: Golang, Header Value: [Fiber]
Header Key: Test, Header Value: [123456 654321]
```
Example with maps.Collect()
```go title="Example with maps.Collect()"
req := client.AcquireRequest()
req.AddHeader("Golang", "Fiber")
req.AddHeader("Test", "123456")
req.AddHeader("Test", "654321")
headers := maps.Collect(req.Headers()) // Collect all headers into a map
for k, v := range headers {
fmt.Printf("Header Key: %s, Header Value: %v\n", k, v)
}
```
```sh
Header Key: Golang, Header Value: [Fiber]
Header Key: Test, Header Value: [123456 654321]
```
### AddHeader
**AddHeader** adds a single header field and its value to the request.
```go title="Signature"
func (r *Request) AddHeader(key, val string) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.AddHeader("Golang", "Fiber")
req.AddHeader("Test", "123456")
req.AddHeader("Test", "654321")
resp, err := req.Get("https://httpbin.org/headers")
if err != nil {
panic(err)
}
fmt.Println(resp.String())
```
```json
{
"headers": {
"Golang": "Fiber",
"Host": "httpbin.org",
"Referer": "",
"Test": "123456,654321",
"User-Agent": "fiber",
"X-Amzn-Trace-Id": "Root=1-664105d2-033cf7173457adb56d9e7193"
}
}
```
### SetHeader
**SetHeader** sets a single header field and its value, overriding any previously set header with the same key.
```go title="Signature"
func (r *Request) SetHeader(key, val string) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.SetHeader("Test", "123456")
req.SetHeader("Test", "654321")
resp, err := req.Get("https://httpbin.org/headers")
if err != nil {
panic(err)
}
fmt.Println(resp.String())
```
```json
{
"headers": {
"Host": "httpbin.org",
"Referer": "",
"Test": "654321",
"User-Agent": "fiber",
"X-Amzn-Trace-Id": "Root=1-664105e5-5d676ba348450cdb62847f04"
}
}
```
### AddHeaders
**AddHeaders** adds multiple headers at once from a map of string slices.
```go title="Signature"
func (r *Request) AddHeaders(h map[string][]string) *Request
```
### SetHeaders
**SetHeaders** sets multiple headers at once from a map of strings, overriding any previously set headers.
```go title="Signature"
func (r *Request) SetHeaders(h map[string]string) *Request
```
## Param
**Param** returns all values associated with a given query parameter key.
```go title="Signature"
func (r *Request) Param(key string) []string
```
### Params
**Params** returns an iterator over all query parameters. Use `maps.Collect()` if you need them in a map. The returned values are valid only until the request is released.
```go title="Signature"
func (r *Request) Params() iter.Seq2[string, []string]
```
### AddParam
**AddParam** adds a single query parameter key-value pair.
```go title="Signature"
func (r *Request) AddParam(key, val string) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.AddParam("name", "john")
req.AddParam("hobbies", "football")
req.AddParam("hobbies", "basketball")
resp, err := req.Get("https://httpbin.org/response-headers")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```json
{
"Content-Length": "145",
"Content-Type": "application/json",
"hobbies": [
"football",
"basketball"
],
"name": "john"
}
```
### SetParam
**SetParam** sets a single query parameter key-value pair, overriding any previously set values for that key.
```go title="Signature"
func (r *Request) SetParam(key, val string) *Request
```
### AddParams
**AddParams** adds multiple query parameters from a map of string slices.
```go title="Signature"
func (r *Request) AddParams(m map[string][]string) *Request
```
### SetParams
**SetParams** sets multiple query parameters from a map of strings, overriding previously set values.
```go title="Signature"
func (r *Request) SetParams(m map[string]string) *Request
```
### SetParamsWithStruct
**SetParamsWithStruct** sets multiple query parameters from a struct. Nested structs are not supported.
```go title="Signature"
func (r *Request) SetParamsWithStruct(v any) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.SetParamsWithStruct(struct {
Name string `json:"name"`
Hobbies []string `json:"hobbies"`
}{
Name: "John Doe",
Hobbies: []string{
"Football",
"Basketball",
},
})
resp, err := req.Get("https://httpbin.org/response-headers")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```json
{
"Content-Length": "147",
"Content-Type": "application/json",
"Hobbies": [
"Football",
"Basketball"
],
"Name": "John Doe"
}
```
### DelParams
**DelParams** removes one or more query parameters by their keys.
```go title="Signature"
func (r *Request) DelParams(key ...string) *Request
```
## UserAgent
**UserAgent** returns the user agent currently set in the request.
```go title="Signature"
func (r *Request) UserAgent() string
```
## SetUserAgent
**SetUserAgent** sets the user agent header for the request, overriding the one set at the client level if any.
```go title="Signature"
func (r *Request) SetUserAgent(ua string) *Request
```
## Boundary
**Boundary** returns the multipart boundary used by the request.
```go title="Signature"
func (r *Request) Boundary() string
```
## SetBoundary
**SetBoundary** sets the multipart boundary for file uploads.
```go title="Signature"
func (r *Request) SetBoundary(b string) *Request
```
## Referer
**Referer** returns the Referer header value currently set in the request.
```go title="Signature"
func (r *Request) Referer() string
```
## SetReferer
**SetReferer** sets the Referer header for the request, overriding the one set at the client level if any.
```go title="Signature"
func (r *Request) SetReferer(referer string) *Request
```
## Cookie
**Cookie** returns the value of the specified cookie. If the cookie does not exist, it returns an empty string.
```go title="Signature"
func (r *Request) Cookie(key string) string
```
### Cookies
**Cookies** returns an iterator over all cookies set in the request. Use `maps.Collect()` to gather them into a map.
```go title="Signature"
func (r *Request) Cookies() iter.Seq2[string, string]
```
### SetCookie
**SetCookie** sets a single cookie key-value pair, overriding any previously set cookie with the same key.
```go title="Signature"
func (r *Request) SetCookie(key, val string) *Request
```
### SetCookies
**SetCookies** sets multiple cookies from a map, overriding previously set values.
```go title="Signature"
func (r *Request) SetCookies(m map[string]string) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.SetCookies(map[string]string{
"cookie1": "value1",
"cookie2": "value2",
})
resp, err := req.Get("https://httpbin.org/cookies")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```json
{
"cookies": {
"cookie1": "value1",
"cookie2": "value2"
}
}
```
### SetCookiesWithStruct
**SetCookiesWithStruct** sets multiple cookies from a struct.
```go title="Signature"
func (r *Request) SetCookiesWithStruct(v any) *Request
```
### DelCookies
**DelCookies** removes one or more cookies by their keys.
```go title="Signature"
func (r *Request) DelCookies(key ...string) *Request
```
## PathParam
**PathParam** returns the value of a named path parameter. If not found, returns an empty string.
```go title="Signature"
func (r *Request) PathParam(key string) string
```
### PathParams
**PathParams** returns an iterator over all path parameters in the request. Use `maps.Collect()` to convert them into a map.
```go title="Signature"
func (r *Request) PathParams() iter.Seq2[string, string]
```
### SetPathParam
**SetPathParam** sets a single path parameter key-value pair, overriding previously set values.
```go title="Signature"
func (r *Request) SetPathParam(key, val string) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.SetPathParam("base64", "R29maWJlcg==")
resp, err := req.Get("https://httpbin.org/base64/:base64")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```plaintext
Gofiber
```
### SetPathParams
**SetPathParams** sets multiple path parameters at once, overriding previously set values.
```go title="Signature"
func (r *Request) SetPathParams(m map[string]string) *Request
```
### SetPathParamsWithStruct
**SetPathParamsWithStruct** sets multiple path parameters from a struct.
```go title="Signature"
func (r *Request) SetPathParamsWithStruct(v any) *Request
```
### DelPathParams
**DelPathParams** deletes one or more path parameters by their keys.
```go title="Signature"
func (r *Request) DelPathParams(key ...string) *Request
```
### ResetPathParams
**ResetPathParams** deletes all path parameters.
```go title="Signature"
func (r *Request) ResetPathParams() *Request
```
## SetJSON
**SetJSON** sets the request body to a JSON-encoded payload.
```go title="Signature"
func (r *Request) SetJSON(v any) *Request
```
## SetXML
**SetXML** sets the request body to an XML-encoded payload.
```go title="Signature"
func (r *Request) SetXML(v any) *Request
```
## SetCBOR
**SetCBOR** sets the request body to a CBOR-encoded payload. It automatically sets the `Content-Type` to `application/cbor`.
```go title="Signature"
func (r *Request) SetCBOR(v any) *Request
```
## SetRawBody
**SetRawBody** sets the request body to raw bytes.
```go title="Signature"
func (r *Request) SetRawBody(v []byte) *Request
```
## FormData
**FormData** returns all values associated with the given form data field.
```go title="Signature"
func (r *Request) FormData(key string) []string
```
### AllFormData
**AllFormData** returns an iterator over all form data fields. Use `maps.Collect()` if needed.
```go title="Signature"
func (r *Request) AllFormData() iter.Seq2[string, []string]
```
### AddFormData
**AddFormData** adds a single form data key-value pair.
```go title="Signature"
func (r *Request) AddFormData(key, val string) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.AddFormData("points", "80")
req.AddFormData("points", "90")
req.AddFormData("points", "100")
resp, err := req.Post("https://httpbin.org/post")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```json
{
"args": {},
"data": "",
"files": {},
"form": {
"points": [
"80",
"90",
"100"
]
},
// ...
}
```
### SetFormData
**SetFormData** sets a single form data field, overriding any previously set values.
```go title="Signature"
func (r *Request) SetFormData(key, val string) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.SetFormData("name", "john")
req.SetFormData("email", "john@doe.com")
resp, err := req.Post("https://httpbin.org/post")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```json
{
"args": {},
"data": "",
"files": {},
"form": {
"email": "john@doe.com",
"name": "john"
},
// ...
}
```
### AddFormDataWithMap
**AddFormDataWithMap** adds multiple form data fields and values from a map of string slices.
```go title="Signature"
func (r *Request) AddFormDataWithMap(m map[string][]string) *Request
```
### SetFormDataWithMap
**SetFormDataWithMap** sets multiple form data fields from a map of strings.
```go title="Signature"
func (r *Request) SetFormDataWithMap(m map[string]string) *Request
```
### SetFormDataWithStruct
**SetFormDataWithStruct** sets multiple form data fields from a struct.
```go title="Signature"
func (r *Request) SetFormDataWithStruct(v any) *Request
```
### DelFormData
**DelFormData** deletes one or more form data fields by their keys.
```go title="Signature"
func (r *Request) DelFormData(key ...string) *Request
```
## File
**File** returns a file from the request by its name. If no name was provided, it attempts to match by path.
```go title="Signature"
func (r *Request) File(name string) *File
```
### Files
**Files** returns all files in the request as a slice. The returned slice is valid only until the request is released.
```go title="Signature"
func (r *Request) Files() []*File
```
### FileByPath
**FileByPath** returns a file from the request by its file path.
```go title="Signature"
func (r *Request) FileByPath(path string) *File
```
### AddFile
**AddFile** adds a single file to the request from a file path.
```go title="Signature"
func (r *Request) AddFile(path string) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.AddFile("test.txt")
resp, err := req.Post("https://httpbin.org/post")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```json
{
"args": {},
"data": "",
"files": {
"file1": "This is an empty file!\n"
},
"form": {},
// ...
}
```
### AddFileWithReader
**AddFileWithReader** adds a single file to the request from an `io.ReadCloser`.
```go title="Signature"
func (r *Request) AddFileWithReader(name string, reader io.ReadCloser) *Request
```
Example
```go title="Example"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
buf := bytes.NewBuffer([]byte("Hello, World!"))
req.AddFileWithReader("test.txt", io.NopCloser(buf))
resp, err := req.Post("https://httpbin.org/post")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```json
{
"args": {},
"data": "",
"files": {
"file1": "Hello, World!"
},
"form": {},
// ...
}
```
### AddFiles
**AddFiles** adds multiple files to the request at once.
```go title="Signature"
func (r *Request) AddFiles(files ...*File) *Request
```
## Timeout
**Timeout** returns the timeout duration set in the request.
```go title="Signature"
func (r *Request) Timeout() time.Duration
```
## SetTimeout
**SetTimeout** sets a timeout for the request, overriding any timeout set at the client level.
```go title="Signature"
func (r *Request) SetTimeout(t time.Duration) *Request
```
Example 1
```go title="Example 1"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.SetTimeout(5 * time.Second)
resp, err := req.Get("https://httpbin.org/delay/4")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```json
{
"args": {},
"data": "",
"files": {},
"form": {},
// ...
}
```
Example 2
```go title="Example 2"
req := client.AcquireRequest()
defer client.ReleaseRequest(req)
req.SetTimeout(5 * time.Second)
resp, err := req.Get("https://httpbin.org/delay/6")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
```shell
panic: timeout or cancel
goroutine 1 [running]:
main.main()
main.go:18 +0xeb
exit status 2
```
## MaxRedirects
**MaxRedirects** returns the maximum number of redirects allowed for the request.
```go title="Signature"
func (r *Request) MaxRedirects() int
```
## SetMaxRedirects
**SetMaxRedirects** sets the maximum number of redirects for the request, overriding the client's setting.
```go title="Signature"
func (r *Request) SetMaxRedirects(count int) *Request
```
## Send
**Send** executes the HTTP request and returns a `Response`.
```go title="Signature"
func (r *Request) Send() (*Response, error)
```
## Reset
**Reset** clears the `Request` object, making it ready for reuse. This is used by `ReleaseRequest`.
```go title="Signature"
func (r *Request) Reset()
```
## Header
**Header** is a wrapper around `fasthttp.RequestHeader`, storing headers for both the client and request.
```go
type Header struct {
*fasthttp.RequestHeader
}
```
### PeekMultiple
**PeekMultiple** returns multiple values associated with the same header key.
```go title="Signature"
func (h *Header) PeekMultiple(key string) []string
```
### AddHeaders
**AddHeaders** adds multiple headers from a map of string slices.
```go title="Signature"
func (h *Header) AddHeaders(r map[string][]string)
```
### SetHeaders
**SetHeaders** sets multiple headers from a map of strings, overriding previously set headers.
```go title="Signature"
func (h *Header) SetHeaders(r map[string]string)
```
## QueryParam
**QueryParam** is a wrapper around `fasthttp.Args`, storing query parameters.
```go
type QueryParam struct {
*fasthttp.Args
}
```
### Keys
**Keys** returns all keys in the query parameters.
```go title="Signature"
func (p *QueryParam) Keys() []string
```
### AddParams
**AddParams** adds multiple query parameters from a map of string slices.
```go title="Signature"
func (p *QueryParam) AddParams(r map[string][]string)
```
### SetParams
**SetParams** sets multiple query parameters from a map of strings, overriding previously set values.
```go title="Signature"
func (p *QueryParam) SetParams(r map[string]string)
```
### SetParamsWithStruct
**SetParamsWithStruct** sets multiple query parameters from a struct. Nested structs are not supported.
```go title="Signature"
func (p *QueryParam) SetParamsWithStruct(v any)
```
## Cookie
**Cookie** is a map that stores cookies.
```go
type Cookie map[string]string
```
### Add
**Add** adds a cookie key-value pair.
```go title="Signature"
func (c Cookie) Add(key, val string)
```
### Del
**Del** removes a cookie by its key.
```go title="Signature"
func (c Cookie) Del(key string)
```
### SetCookie
**SetCookie** sets a single cookie key-value pair, overriding previously set values.
```go title="Signature"
func (c Cookie) SetCookie(key, val string)
```
### SetCookies
**SetCookies** sets multiple cookies from a map of strings.
```go title="Signature"
func (c Cookie) SetCookies(m map[string]string)
```
### SetCookiesWithStruct
**SetCookiesWithStruct** sets multiple cookies from a struct.
```go title="Signature"
func (c Cookie) SetCookiesWithStruct(v any)
```
### DelCookies
**DelCookies** deletes one or more cookies by their keys.
```go title="Signature"
func (c Cookie) DelCookies(key ...string)
```
### All
**All** returns an iterator over all cookies. The key and value returned
should not be retained after the loop ends.
```go title="Signature"
func (c Cookie) All() iter.Seq2[string, string]
```
### Reset
**Reset** clears all cookies.
```go title="Signature"
func (c Cookie) Reset()
```
## PathParam
**PathParam** is a map that stores path parameters.
```go
type PathParam map[string]string
```
### Add
**Add** adds a path parameter key-value pair.
```go title="Signature"
func (p PathParam) Add(key, val string)
```
### Del
**Del** removes a path parameter by its key.
```go title="Signature"
func (p PathParam) Del(key string)
```
### SetParam
**SetParam** sets a single path parameter key-value pair, overriding previously set values.
```go title="Signature"
func (p PathParam) SetParam(key, val string)
```
### SetParams
**SetParams** sets multiple path parameters from a map of strings.
```go title="Signature"
func (p PathParam) SetParams(m map[string]string)
```
### SetParamsWithStruct
**SetParamsWithStruct** sets multiple path parameters from a struct.
```go title="Signature"
func (p PathParam) SetParamsWithStruct(v any)
```
### DelParams
**DelParams** deletes one or more path parameters by their keys.
```go title="Signature"
func (p PathParam) DelParams(key ...string)
```
### All
**All** returns an iterator over all path parameters. The key and value returned
should not be retained after the loop ends.
```go title="Signature"
func (p PathParam) All() iter.Seq2[string, string]
```
### Reset
**Reset** clears all path parameters.
```go title="Signature"
func (p PathParam) Reset()
```
## FormData
**FormData** is a wrapper around `fasthttp.Args`, used to handle URL-encoded and form-data (multipart) request bodies.
```go
type FormData struct {
*fasthttp.Args
}
```
### Keys
**Keys** returns all form data keys.
```go title="Signature"
func (f *FormData) Keys() []string
```
### Add
**Add** adds a single form field key-value pair.
```go title="Signature"
func (f *FormData) Add(key, val string)
```
### Set
**Set** sets a single form field key-value pair, overriding any previously set values.
```go title="Signature"
func (f *FormData) Set(key, val string)
```
### AddWithMap
**AddWithMap** adds multiple form fields from a map of string slices.
```go title="Signature"
func (f *FormData) AddWithMap(m map[string][]string)
```
### SetWithMap
**SetWithMap** sets multiple form fields from a map of strings.
```go title="Signature"
func (f *FormData) SetWithMap(m map[string]string)
```
### SetWithStruct
**SetWithStruct** sets multiple form fields from a struct.
```go title="Signature"
func (f *FormData) SetWithStruct(v any)
```
### DelData
**DelData** deletes one or more form fields by their keys.
```go title="Signature"
func (f *FormData) DelData(key ...string)
```
### Reset
**Reset** clears all form data fields.
```go title="Signature"
func (f *FormData) Reset()
```
## File
**File** represents a file to be uploaded. It can be specified by name, path, or an `io.ReadCloser`.
```go
type File struct {
name string
fieldName string
path string
reader io.ReadCloser
}
```
### AcquireFile
**AcquireFile** returns a `File` from the pool and applies any provided `SetFileFunc` functions to it. Release it with `ReleaseFile` when done.
```go title="Signature"
func AcquireFile(setter ...SetFileFunc) *File
```
### ReleaseFile
**ReleaseFile** returns the `File` to the pool. Do not use the file afterward.
```go title="Signature"
func ReleaseFile(f *File)
```
### SetName
**SetName** sets the file's name.
```go title="Signature"
func (f *File) SetName(n string)
```
### SetFieldName
**SetFieldName** sets the field name of the file in the multipart form.
```go title="Signature"
func (f *File) SetFieldName(n string)
```
### SetPath
**SetPath** sets the file's path.
```go title="Signature"
func (f *File) SetPath(p string)
```
### SetReader
**SetReader** sets the file's `io.ReadCloser`. The reader is closed automatically when the request body is parsed.
```go title="Signature"
func (f *File) SetReader(r io.ReadCloser)
```
### Reset
**Reset** clears the file's fields.
```go title="Signature"
func (f *File) Reset()
```
================================================
FILE: docs/client/response.md
================================================
---
id: response
title: 📥 Response
description: >-
Response methods of Gofiber HTTP client.
sidebar_position: 3
---
The `Response` struct in Fiber's HTTP client represents the server's reply and exposes:
- **Status Code**: The HTTP status code returned by the server (e.g., `200 OK`, `404 Not Found`).
- **Headers**: All HTTP headers returned by the server, providing additional response-related information.
- **Body**: The response body content, which can be JSON, XML, plain text, or other formats.
- **Cookies**: Any cookies the server sent along with the response.
It makes it easy to inspect and handle data returned by the server.
```go
type Response struct {
client *Client
request *Request
cookie []*fasthttp.Cookie
RawResponse *fasthttp.Response
}
```
## AcquireResponse
**AcquireResponse** returns a new pooled `Response`. Call `ReleaseResponse` when you're done to return it to the pool and limit allocations.
```go title="Signature"
func AcquireResponse() *Response
```
## ReleaseResponse
**ReleaseResponse** puts the `Response` back into the pool. Do not use it after releasing; doing so can trigger data races.
```go title="Signature"
func ReleaseResponse(resp *Response)
```
## Status
**Status** returns the HTTP status message (e.g., `OK`, `Not Found`) associated with the response.
```go title="Signature"
func (r *Response) Status() string
```
## StatusCode
**StatusCode** returns the numeric HTTP status code of the response.
```go title="Signature"
func (r *Response) StatusCode() int
```
## Protocol
**Protocol** returns the HTTP protocol used (e.g., `HTTP/1.1`, `HTTP/2`) for the response.
```go title="Signature"
func (r *Response) Protocol() string
```
Example
```go title="Example"
resp, err := client.Get("https://httpbin.org/get")
if err != nil {
panic(err)
}
fmt.Println(resp.Protocol())
```
**Output:**
```text
HTTP/1.1
```
## Header
**Header** retrieves the value of a specific response header by key. If multiple values exist for the same header, this returns the first one.
```go title="Signature"
func (r *Response) Header(key string) string
```
## Headers
**Headers** returns an iterator over all response headers. Use `maps.Collect()` to convert them into a map if desired. The returned values are only valid until the response is released, so make copies if needed.
```go title="Signature"
func (r *Response) Headers() iter.Seq2[string, []string]
```
Example
```go title="Example"
resp, err := client.Get("https://httpbin.org/get")
if err != nil {
panic(err)
}
for key, values := range resp.Headers() {
fmt.Printf("%s => %s\n", key, strings.Join(values, ", "))
}
```
**Output:**
```text
Date => Wed, 04 Dec 2024 15:28:29 GMT
Connection => keep-alive
Access-Control-Allow-Origin => *
Access-Control-Allow-Credentials => true
```
Example with maps.Collect()
```go title="Example with maps.Collect()"
resp, err := client.Get("https://httpbin.org/get")
if err != nil {
panic(err)
}
headers := maps.Collect(resp.Headers())
for key, values := range headers {
fmt.Printf("%s => %s\n", key, strings.Join(values, ", "))
}
```
**Output:**
```text
Date => Wed, 04 Dec 2024 15:28:29 GMT
Connection => keep-alive
Access-Control-Allow-Origin => *
Access-Control-Allow-Credentials => true
```
## Cookies
**Cookies** returns a slice of all cookies set by the server in this response. The slice is only valid until the response is released.
```go title="Signature"
func (r *Response) Cookies() []*fasthttp.Cookie
```
Example
```go title="Example"
resp, err := client.Get("https://httpbin.org/cookies/set/go/fiber")
if err != nil {
panic(err)
}
cookies := resp.Cookies()
for _, cookie := range cookies {
fmt.Printf("%s => %s\n", string(cookie.Key()), string(cookie.Value()))
}
```
**Output:**
```text
go => fiber
```
## Body
**Body** returns the raw response body as a byte slice.
```go title="Signature"
func (r *Response) Body() []byte
```
## BodyStream
**BodyStream** returns the response body as an `io.Reader`, allowing incremental reading without loading the entire body into memory. This is particularly useful when `Client.SetStreamResponseBody(true)` is enabled.
When streaming is enabled, the underlying stream from fasthttp is returned directly. When streaming is not enabled, a `bytes.Reader` wrapping the body is returned as a fallback.
:::note
When using `BodyStream()`, the response body is consumed as you read. Calling `Body()` afterward may return an empty slice if the stream has been fully read.
:::
```go title="Signature"
func (r *Response) BodyStream() io.Reader
```
Example
```go title="Example"
cc := client.New()
cc.SetStreamResponseBody(true)
resp, err := cc.Get("https://httpbin.org/bytes/1024")
if err != nil {
panic(err)
}
defer resp.Close()
buf := make([]byte, 256)
total, err := io.CopyBuffer(io.Discard, resp.BodyStream(), buf)
if err != nil {
panic(err)
}
fmt.Printf("Read %d bytes\n", total)
```
**Output:**
```text
Read 1024 bytes
```
## IsStreaming
**IsStreaming** returns `true` if the response body is being streamed (i.e., when `Client.SetStreamResponseBody(true)` was set and the underlying transport provided a stream).
```go title="Signature"
func (r *Response) IsStreaming() bool
```
Example
```go title="Example"
cc := client.New()
cc.SetStreamResponseBody(true)
resp, err := cc.Get("https://httpbin.org/get")
if err != nil {
panic(err)
}
defer resp.Close()
if resp.IsStreaming() {
fmt.Println("Response is streaming")
// Use resp.BodyStream() to read incrementally
} else {
fmt.Println("Response is buffered")
// Use resp.Body() for direct access
}
```
## String
**String** returns the response body as a trimmed string.
```go title="Signature"
func (r *Response) String() string
```
## JSON
**JSON** unmarshals the response body into the provided variable `v` using JSON decoding. `v` should be a pointer to a struct or another type compatible with JSON unmarshaling.
```go title="Signature"
func (r *Response) JSON(v any) error
```
Example
```go title="Example"
type Body struct {
Slideshow struct {
Author string `json:"author"`
Date string `json:"date"`
Title string `json:"title"`
} `json:"slideshow"`
}
var out Body
resp, err := client.Get("https://httpbin.org/json")
if err != nil {
panic(err)
}
if err = resp.JSON(&out); err != nil {
panic(err)
}
fmt.Printf("%+v\n", out)
```
**Output:**
```text
{Slideshow:{Author:Yours Truly Date:date of publication Title:Sample Slide Show}}
```
## XML
**XML** unmarshals the response body into the provided variable `v` using XML decoding.
```go title="Signature"
func (r *Response) XML(v any) error
```
## CBOR
**CBOR** unmarshals the response body into `v` using CBOR decoding.
```go title="Signature"
func (r *Response) CBOR(v any) error
```
## Save
**Save** writes the response body to a file or an `io.Writer`. If `v` is a string, it is interpreted as a file path: the file (and any missing directories) is created and the response body is written to it. If `v` is an `io.Writer`, the body is written directly to it.
```go title="Signature"
func (r *Response) Save(v any) error
```
## Reset
**Reset** clears the `Response` object, making it ready for reuse by `ReleaseResponse`.
```go title="Signature"
func (r *Response) Reset()
```
## Close
**Close** releases both the associated `Request` and `Response` objects back to their pools.
:::warning
After calling `Close`, any attempt to use the request or response may result in data races or undefined behavior. Ensure all processing is complete before closing.
:::
```go title="Signature"
func (r *Response) Close()
```
================================================
FILE: docs/client/rest.md
================================================
---
id: rest
title: 🖥️ REST
description: >-
HTTP client for Fiber.
sidebar_position: 1
toc_max_heading_level: 5
---
The Fiber Client is a high-performance HTTP client built on FastHTTP. It handles both internal service calls and external requests with minimal overhead.
## Features
- **Lightweight and fast**: built on FastHTTP for minimal overhead.
- **Flexible configuration**: set global defaults like timeouts or headers and override them per request.
- **Connection pooling**: reuses persistent connections instead of opening new ones.
- **Timeouts and retries**: supports per-request deadlines and retry policies for transient errors.
## Usage
Create a client with any required configuration, then send requests:
```go
package main
import (
"fmt"
"time"
"github.com/gofiber/fiber/v3/client"
)
func main() {
cc := client.New()
cc.SetTimeout(10 * time.Second)
// Send a GET request
resp, err := cc.Get("https://httpbin.org/get")
if err != nil {
panic(err)
}
fmt.Printf("Status: %d\n", resp.StatusCode())
fmt.Printf("Body: %s\n", string(resp.Body()))
}
```
See [examples](examples.md) for more detailed usage.
```go
type Client struct {
mu sync.RWMutex
fasthttp *fasthttp.Client
baseURL string
userAgent string
referer string
header *Header
params *QueryParam
cookies *Cookie
path *PathParam
debug bool
timeout time.Duration
// user-defined request hooks
userRequestHooks []RequestHook
// client package-defined request hooks
builtinRequestHooks []RequestHook
// user-defined response hooks
userResponseHooks []ResponseHook
// client package-defined response hooks
builtinResponseHooks []ResponseHook
jsonMarshal utils.JSONMarshal
jsonUnmarshal utils.JSONUnmarshal
xmlMarshal utils.XMLMarshal
xmlUnmarshal utils.XMLUnmarshal
cborMarshal utils.CBORMarshal
cborUnmarshal utils.CBORUnmarshal
cookieJar *CookieJar
// proxy
proxyURL string
// retry
retryConfig *RetryConfig
// logger
logger log.CommonLogger
}
```
### New
**New** creates and returns a new Client object.
```go title="Signature"
func New() *Client
```
### NewWithClient
**NewWithClient** creates and returns a new Client object from an existing `fasthttp.Client`.
```go title="Signature"
func NewWithClient(c *fasthttp.Client) *Client
```
## REST Methods
These helpers mirror axios-style method names and send HTTP requests using the configured client:
### Get
Sends a GET request.
```go title="Signature"
func (c *Client) Get(url string, cfg ...Config) (*Response, error)
```
### Post
Sends a POST request.
```go title="Signature"
func (c *Client) Post(url string, cfg ...Config) (*Response, error)
```
### Put
Sends a PUT request.
```go title="Signature"
func (c *Client) Put(url string, cfg ...Config) (*Response, error)
```
### Patch
Sends a PATCH request.
```go title="Signature"
func (c *Client) Patch(url string, cfg ...Config) (*Response, error)
```
### Delete
Sends a DELETE request.
```go title="Signature"
func (c *Client) Delete(url string, cfg ...Config) (*Response, error)
```
### Head
Sends a HEAD request.
```go title="Signature"
func (c *Client) Head(url string, cfg ...Config) (*Response, error)
```
### Options
Sends an OPTIONS request.
```go title="Signature"
func (c *Client) Options(url string, cfg ...Config) (*Response, error)
```
### Custom
Sends a request with any HTTP method.
```go title="Signature"
func (c *Client) Custom(url, method string, cfg ...Config) (*Response, error)
```
## Request Configuration
The `Config` type holds per-request parameters. JSON is used to serialize the body by default. If multiple body sources are set, precedence is:
1. Body
2. FormData
3. File
```go
type Config struct {
Ctx context.Context
UserAgent string
Referer string
Header map[string]string
Param map[string]string
Cookie map[string]string
PathParam map[string]string
Timeout time.Duration
MaxRedirects int
Body any
FormData map[string]string
File []*File
}
```
## R
**R** gets a `Request` object from the pool. Call `ReleaseRequest` when finished.
```go title="Signature"
func (c *Client) R() *Request
```
## Hooks
Hooks allow you to add custom logic before a request is sent or after a response is received.
### RequestHook
**RequestHook** returns user-defined request hooks.
```go title="Signature"
func (c *Client) RequestHook() []RequestHook
```
### ResponseHook
**ResponseHook** returns user-defined response hooks.
```go title="Signature"
func (c *Client) ResponseHook() []ResponseHook
```
### AddRequestHook
Adds one or more user-defined request hooks.
```go title="Signature"
func (c *Client) AddRequestHook(h ...RequestHook) *Client
```
### AddResponseHook
Adds one or more user-defined response hooks.
```go title="Signature"
func (c *Client) AddResponseHook(h ...ResponseHook) *Client
```
## JSON
### JSONMarshal
Returns the JSON marshaler function used by the client.
```go title="Signature"
func (c *Client) JSONMarshal() utils.JSONMarshal
```
### JSONUnmarshal
Returns the JSON unmarshaler function used by the client.
```go title="Signature"
func (c *Client) JSONUnmarshal() utils.JSONUnmarshal
```
### SetJSONMarshal
Sets a custom JSON marshaler.
```go title="Signature"
func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client
```
### SetJSONUnmarshal
Sets a custom JSON unmarshaler.
```go title="Signature"
func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client
```
## XML
### XMLMarshal
Returns the XML marshaler function used by the client.
```go title="Signature"
func (c *Client) XMLMarshal() utils.XMLMarshal
```
### XMLUnmarshal
Returns the XML unmarshaler function used by the client.
```go title="Signature"
func (c *Client) XMLUnmarshal() utils.XMLUnmarshal
```
### SetXMLMarshal
Sets a custom XML marshaler.
```go title="Signature"
func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client
```
### SetXMLUnmarshal
Sets a custom XML unmarshaler.
```go title="Signature"
func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client
```
## CBOR
### CBORMarshal
Returns the CBOR marshaler function used by the client.
```go title="Signature"
func (c *Client) CBORMarshal() utils.CBORMarshal
```
### CBORUnmarshal
Returns the CBOR unmarshaler function used by the client.
```go title="Signature"
func (c *Client) CBORUnmarshal() utils.CBORUnmarshal
```
### SetCBORMarshal
Sets a custom CBOR marshaler.
```go title="Signature"
func (c *Client) SetCBORMarshal(f utils.CBORMarshal) *Client
```
### SetCBORUnmarshal
Sets a custom CBOR unmarshaler.
```go title="Signature"
func (c *Client) SetCBORUnmarshal(f utils.CBORUnmarshal) *Client
```
## TLS
### TLSConfig
Returns the client's TLS configuration. If none is set, it initializes a new
configuration with `MinVersion` defaulting to TLS 1.2.
```go title="Signature"
func (c *Client) TLSConfig() *tls.Config
```
### SetTLSConfig
Sets the TLS configuration for the client.
```go title="Signature"
func (c *Client) SetTLSConfig(config *tls.Config) *Client
```
### SetCertificates
Adds client certificates to the TLS configuration.
```go title="Signature"
func (c *Client) SetCertificates(certs ...tls.Certificate) *Client
```
### SetRootCertificate
Adds one or more root certificates to the client's trust store.
```go title="Signature"
func (c *Client) SetRootCertificate(path string) *Client
```
### SetRootCertificateFromString
Adds one or more root certificates from a string.
```go title="Signature"
func (c *Client) SetRootCertificateFromString(pem string) *Client
```
## SetProxyURL
Sets a proxy URL for the client. All subsequent requests will use this proxy.
```go title="Signature"
func (c *Client) SetProxyURL(proxyURL string) error
```
## Response Streaming
### StreamResponseBody
Returns whether response body streaming is enabled. When enabled, the response body is not fully loaded into memory and can be read as a stream using `Response.BodyStream()`. This is useful for handling large responses or server-sent events (SSE).
```go title="Signature"
func (c *Client) StreamResponseBody() bool
```
### SetStreamResponseBody
Enables or disables response body streaming. When enabled, responses can be consumed incrementally without loading the entire body into memory.
```go title="Signature"
func (c *Client) SetStreamResponseBody(enable bool) *Client
```
**Example:**
```go title="Example"
cc := client.New()
cc.SetStreamResponseBody(true)
resp, err := cc.Get("https://example.com/large-file")
if err != nil {
panic(err)
}
defer resp.Close()
// Check if response is streaming
if resp.IsStreaming() {
// Read body incrementally
reader := resp.BodyStream()
buf := make([]byte, 4096)
for {
n, err := reader.Read(buf)
if n > 0 {
// Process chunk...
}
if err == io.EOF {
break
}
if err != nil {
panic(err)
}
}
} else {
// Regular body access
body := resp.Body()
fmt.Println(string(body))
}
```
**Server-Sent Events Example:**
```go title="SSE Example"
cc := client.New()
cc.SetStreamResponseBody(true)
resp, err := cc.Get("https://example.com/events")
if err != nil {
panic(err)
}
defer resp.Close()
reader := bufio.NewReader(resp.BodyStream())
for {
line, err := reader.ReadString('\n')
if err == io.EOF {
break
}
if err != nil {
panic(err)
}
fmt.Print(line) // Process SSE event
}
```
## RetryConfig
Returns the retry configuration of the client.
```go title="Signature"
func (c *Client) RetryConfig() *RetryConfig
```
## SetRetryConfig
Sets the retry configuration for the client.
```go title="Signature"
func (c *Client) SetRetryConfig(config *RetryConfig) *Client
```
## BaseURL
### BaseURL
**BaseURL** returns the base URL currently set in the client.
```go title="Signature"
func (c *Client) BaseURL() string
```
### SetBaseURL
Sets a base URL prefix for all requests made by the client.
```go title="Signature"
func (c *Client) SetBaseURL(url string) *Client
```
**Example:**
```go title="Example"
cc := client.New()
cc.SetBaseURL("https://httpbin.org/")
resp, err := cc.Get("/get")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
**Output:**
```json
{
"args": {},
...
}
```
## Headers
### Header
Retrieves all values of a header key at the client level. The returned values apply to all requests.
```go title="Signature"
func (c *Client) Header(key string) []string
```
### AddHeader
Adds a single header to all requests initiated by this client.
```go title="Signature"
func (c *Client) AddHeader(key, val string) *Client
```
### SetHeader
Sets a single header, overriding any existing headers with the same key.
```go title="Signature"
func (c *Client) SetHeader(key, val string) *Client
```
### AddHeaders
Adds multiple headers at once; they apply to all future requests from this client.
```go title="Signature"
func (c *Client) AddHeaders(h map[string][]string) *Client
```
### SetHeaders
Sets multiple headers at once, overriding previously set headers.
```go title="Signature"
func (c *Client) SetHeaders(h map[string]string) *Client
```
## Query Parameters
### Param
Returns the values for a given query parameter key.
```go title="Signature"
func (c *Client) Param(key string) []string
```
### AddParam
Adds a single query parameter for all requests.
```go title="Signature"
func (c *Client) AddParam(key, val string) *Client
```
### SetParam
Sets a single query parameter, overriding previously set values.
```go title="Signature"
func (c *Client) SetParam(key, val string) *Client
```
### AddParams
Adds multiple query parameters from a map of string slices.
```go title="Signature"
func (c *Client) AddParams(m map[string][]string) *Client
```
### SetParams
Sets multiple query parameters from a map, overriding previously set values.
```go title="Signature"
func (c *Client) SetParams(m map[string]string) *Client
```
### SetParamsWithStruct
Sets multiple query parameters from a struct. Nested structs are not currently supported.
```go title="Signature"
func (c *Client) SetParamsWithStruct(v any) *Client
```
### DelParams
Deletes one or more query parameters.
```go title="Signature"
func (c *Client) DelParams(key ...string) *Client
```
## UserAgent & Referer
### SetUserAgent
Sets the user agent header for all requests.
```go title="Signature"
func (c *Client) SetUserAgent(ua string) *Client
```
### SetReferer
Sets the referer header for all requests.
```go title="Signature"
func (c *Client) SetReferer(r string) *Client
```
## Path Parameters
### PathParam
Returns the value of a named path parameter, if set.
```go title="Signature"
func (c *Client) PathParam(key string) string
```
### SetPathParam
Sets a single path parameter.
```go title="Signature"
func (c *Client) SetPathParam(key, val string) *Client
```
### SetPathParams
Sets multiple path parameters at once.
```go title="Signature"
func (c *Client) SetPathParams(m map[string]string) *Client
```
### SetPathParamsWithStruct
Sets multiple path parameters from a struct.
```go title="Signature"
func (c *Client) SetPathParamsWithStruct(v any) *Client
```
### DelPathParams
Deletes one or more path parameters.
```go title="Signature"
func (c *Client) DelPathParams(key ...string) *Client
```
## Cookies
### Cookie
Returns the value of a named cookie if set at the client level.
```go title="Signature"
func (c *Client) Cookie(key string) string
```
### SetCookie
Sets a single cookie for all requests.
```go title="Signature"
func (c *Client) SetCookie(key, val string) *Client
```
**Example:**
```go title="Example"
cc := client.New()
cc.SetCookie("john", "doe")
resp, err := cc.Get("https://httpbin.org/cookies")
if err != nil {
panic(err)
}
fmt.Println(string(resp.Body()))
```
**Output:**
```json
{
"cookies": {
"john": "doe"
}
}
```
### SetCookies
Sets multiple cookies at once.
```go title="Signature"
func (c *Client) SetCookies(m map[string]string) *Client
```
### SetCookiesWithStruct
Sets multiple cookies from a struct.
```go title="Signature"
func (c *Client) SetCookiesWithStruct(v any) *Client
```
### DelCookies
Deletes one or more cookies.
```go title="Signature"
func (c *Client) DelCookies(key ...string) *Client
```
## Timeout
### SetTimeout
Sets a default timeout for all requests, which can be overridden per request.
```go title="Signature"
func (c *Client) SetTimeout(t time.Duration) *Client
```
## Debugging
### Debug
Enables debug-level logging output.
```go title="Signature"
func (c *Client) Debug() *Client
```
### DisableDebug
Disables debug-level logging output.
```go title="Signature"
func (c *Client) DisableDebug() *Client
```
## Cookie Jar
### SetCookieJar
Assigns a cookie jar to the client to store and manage cookies across requests.
```go title="Signature"
func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client
```
## Dial & Logger
### SetDial
Sets a custom dial function.
```go title="Signature"
func (c *Client) SetDial(dial fasthttp.DialFunc) *Client
```
### SetLogger
Sets the logger instance used by the client.
```go title="Signature"
func (c *Client) SetLogger(logger log.CommonLogger) *Client
```
### Logger
Returns the current logger instance.
```go title="Signature"
func (c *Client) Logger() log.CommonLogger
```
## Reset
### Reset
Clears and resets the client to its default state and reinstates the default
`fasthttp.Client` transport.
```go title="Signature"
func (c *Client) Reset()
```
## Default Client
Fiber provides a default client (created with `New()`). You can configure it or replace it as needed.
### C
**C** returns the default client.
```go title="Signature"
func C() *Client
```
### Get
Get is a convenience method that sends a GET request using the `defaultClient`.
```go title="Signature"
func Get(url string, cfg ...Config) (*Response, error)
```
### Post
Post is a convenience method that sends a POST request using the `defaultClient`.
```go title="Signature"
func Post(url string, cfg ...Config) (*Response, error)
```
### Put
Put is a convenience method that sends a PUT request using the `defaultClient`.
```go title="Signature"
func Put(url string, cfg ...Config) (*Response, error)
```
### Patch
Patch is a convenience method that sends a PATCH request using the `defaultClient`.
```go title="Signature"
func Patch(url string, cfg ...Config) (*Response, error)
```
### Delete
Delete is a convenience method that sends a DELETE request using the `defaultClient`.
```go title="Signature"
func Delete(url string, cfg ...Config) (*Response, error)
```
### Head
Head is a convenience method that sends a HEAD request using the `defaultClient`.
```go title="Signature"
func Head(url string, cfg ...Config) (*Response, error)
```
### Options
Options is a convenience method that sends an OPTIONS request using the `defaultClient`.
```go title="Signature"
func Options(url string, cfg ...Config) (*Response, error)
```
### Replace
**Replace** replaces the default client with a new one. It returns a function that can restore the old client.
:::caution
Do not modify the default client concurrently.
:::
```go title="Signature"
func Replace(c *Client) func()
```
================================================
FILE: docs/extra/_category_.json
================================================
{
"label": "\uD83E\uDDE9 Extra",
"position": 8,
"link": {
"type": "generated-index",
"description": "Extra contents for Fiber."
}
}
================================================
FILE: docs/extra/benchmarks.md
================================================
---
id: benchmarks
title: 📊 Benchmarks
description: >-
These benchmarks aim to compare the performance of Fiber and other web
frameworks.
sidebar_position: 2
---
## TechEmpower
[TechEmpower](https://www.techempower.com/benchmarks/#section=test&runid=1d5bfc8a-5c4a-4fb2-a792-ad967f1eb138) provides a performance comparison of many web application frameworks that execute fundamental tasks such as JSON serialization, database access, and server-side template rendering.
Each framework runs under a realistic production configuration. Results are recorded on both cloud instances and physical hardware. The test implementations are community contributed and live in the [FrameworkBenchmarks repository](https://github.com/TechEmpower/FrameworkBenchmarks).
* Fiber `v3.0.0`
* 56 Cores Intel(R) Xeon(R) Gold 6330 CPU @ 2.00GHz (Three homogeneous ProLiant DL360 Gen10 Plus)
* 64GB RAM
* Enterprise SSD
* Ubuntu
* Mellanox Technologies MT28908 Family ConnectX-6 40Gbps Ethernet
### Plaintext
The Plaintext test measures basic request routing and demonstrates the capacity of high-performance platforms. Requests are pipelined, and the tiny response body demands high throughput to saturate the benchmark's network link.
See the [Plaintext requirements](https://github.com/TechEmpower/FrameworkBenchmarks/wiki/Project-Information-Framework-Tests-Overview#plaintext).
**Fiber** handled **11,987,976** responses per second with an average latency of **1.0** ms.
**Express** handled **1,204,969** responses per second with an average latency of **8.8** ms.


### Data Updates
**Fiber** handled **29,984** responses per second with an average latency of **16.9** ms.
**Express** handled **54,887** responses per second with an average latency of **9.2** ms.


### Multiple Queries
**Fiber** handled **54,002** responses per second with an average latency of **9.4** ms.
**Express** handled **85,011** responses per second with an average latency of **6.0** ms.


### Single Query
**Fiber** handled **953,016** responses per second with an average latency of **0.6** ms.
**Express** handled **441,543** responses per second with an average latency of **1.3** ms.


### JSON Serialization
**Fiber** handled **2,363,294** responses per second with an average latency of **0.2** ms.
**Express** handled **949,717** responses per second with an average latency of **0.5** ms.


================================================
FILE: docs/extra/faq.md
================================================
---
id: faq
title: 🤔 FAQ
description: >-
Frequently asked questions. Open an issue if you have another question to add.
sidebar_position: 1
---
## How should I structure my application?
There's no single answer; the ideal structure depends on your application's scale and team. Fiber makes no assumptions about project layout.
Routes and other application logic can live in any files or directories. For inspiration, see:
* [gofiber/boilerplate](https://github.com/gofiber/boilerplate)
* [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate)
* [Youtube - Building a REST API using Gorm and Fiber](https://www.youtube.com/watch?v=Iq2qT0fRhAA)
* [embedmode/fiberseed](https://github.com/embedmode/fiberseed)
## How do I handle custom 404 responses?
If you're using v2.32.0 or later, handle 404 responses in a custom error handler; see [Error Handling](../guide/error-handling.md#custom-error-handler) for details.
If you're using v2.31.0 or earlier, the error handler will not capture 404 errors. Instead, add a middleware function at the very bottom of the stack (below all other functions) to handle a 404 response:
```go title="Example"
app.Use(func(c fiber.Ctx) error {
return c.Status(fiber.StatusNotFound).SendString("Sorry can't find that!")
})
```
## How can I use live reload?
[Air](https://github.com/air-verse/air) automatically restarts your Go application when source files change, speeding up development.
To use Air in a Fiber project, follow these steps:
* Install Air by downloading the appropriate binary for your operating system from the GitHub release page or by building the tool from source.
* Create a configuration file for Air in your project directory, such as `.air.toml` or `air.conf`. Here's a sample configuration file that works with Fiber:
```toml
# .air.toml
root = "."
tmp_dir = "tmp"
[build]
cmd = "go build -o ./tmp/main ."
bin = "./tmp/main"
delay = 1000 # ms
exclude_dir = ["assets", "tmp", "vendor"]
include_ext = ["go", "tpl", "tmpl", "html"]
exclude_regex = ["_test\\.go"]
```
* Start your Fiber application with Air by running the following command:
```sh
air
```
As you edit source files, Air detects the changes and restarts the application.
A complete example is available in the [Fiber Recipes repository](https://github.com/gofiber/recipes/tree/master/air) and shows how to configure Air for a Fiber project.
## How do I set up an error handler?
To override the default error handler, provide a custom one in the [Config](../api/fiber.md#errorhandler) when creating a new [Fiber instance](../api/fiber.md#new).
```go title="Example"
app := fiber.New(fiber.Config{
ErrorHandler: func(c fiber.Ctx, err error) error {
return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
},
})
```
We have a dedicated page explaining how error handling works in Fiber, see [Error Handling](../guide/error-handling.md).
## Which template engines does Fiber support?
Fiber currently supports 9 template engines through the [gofiber/template](https://docs.gofiber.io/template/) package:
* [ace](https://docs.gofiber.io/template/ace/)
* [amber](https://docs.gofiber.io/template/amber/)
* [django](https://docs.gofiber.io/template/django/)
* [handlebars](https://docs.gofiber.io/template/handlebars/)
* [html](https://docs.gofiber.io/template/html/)
* [jet](https://docs.gofiber.io/template/jet/)
* [mustache](https://docs.gofiber.io/template/mustache/)
* [pug](https://docs.gofiber.io/template/pug/)
* [slim](https://docs.gofiber.io/template/slim/)
To learn more about using templates in Fiber, see [Templates](../guide/templates.md).
## Does Fiber have a community chat?
Yes, we have a [Discord](https://gofiber.io/discord) server with rooms for every topic.
If you have questions or just want to chat, join us via this [invite link](https://gofiber.io/discord).

## Does Fiber support subdomain routing?
Yes. The example below routes each request to a separate Fiber app based on the requested host:
```go
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/logger"
)
// Host wraps the Fiber app that serves a single virtual host.
type Host struct {
	Fiber *fiber.App
}

func main() {
	// Hosts, keyed by the value of the Host header (hostname and port).
	hosts := map[string]*Host{}

	//-----
	// API
	//-----
	api := fiber.New()
	api.Use(logger.New(logger.Config{
		Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
	}))
	hosts["api.localhost:3000"] = &Host{api}
	api.Get("/", func(c fiber.Ctx) error {
		return c.SendString("API")
	})

	//------
	// Blog
	//------
	blog := fiber.New()
	blog.Use(logger.New(logger.Config{
		Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
	}))
	hosts["blog.localhost:3000"] = &Host{blog}
	blog.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Blog")
	})

	//---------
	// Website
	//---------
	site := fiber.New()
	site.Use(logger.New(logger.Config{
		Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
	}))
	hosts["localhost:3000"] = &Host{site}
	site.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Website")
	})

	// Server
	app := fiber.New()
	app.Use(func(c fiber.Ctx) error {
		// The map keys include the port (e.g. "api.localhost:3000"), so look
		// them up with c.Host(); c.Hostname() drops the port in Fiber v3.
		host := hosts[c.Host()]
		if host == nil {
			return c.SendStatus(fiber.StatusNotFound)
		}
		// App.Handler() is a fasthttp.RequestHandler, so hand it the underlying
		// *fasthttp.RequestCtx (c.Context() returns a context.Context in v3).
		host.Fiber.Handler()(c.RequestCtx())
		return nil
	})

	log.Fatal(app.Listen(":3000"))
}
```
For more information, see issue [#750](https://github.com/gofiber/fiber/issues/750).
## How can I handle conversions between Fiber and net/http?
Fiber can register common `net/http` handlers directly—just pass an
`http.Handler`, `http.HandlerFunc`, compatible function, or even a native
`fasthttp.RequestHandler` to your routing method. For other interoperability scenarios, the `adaptor` middleware provides
utilities for converting between Fiber and `net/http`. It allows seamless
integration of `net/http` handlers, middleware, and requests into Fiber
applications, and vice versa.
:::caution Performance trade-offs
Converted `net/http` handlers run through a compatibility layer. They won't expose
`fiber.Ctx` or Fiber-specific helpers, and the extra adaptation work makes them slower
than native Fiber handlers. Use them when interoperability matters, but prefer Fiber
handlers for maximum performance.
:::
For details on how to:
* Convert `net/http` handlers to Fiber handlers
* Convert Fiber handlers to `net/http` handlers
* Convert `fiber.Ctx` to `http.Request`
See the dedicated documentation: [Adaptor Documentation](../middleware/adaptor.md).
================================================
FILE: docs/extra/internal.md
================================================
---
title: 🏗️ Internal Architecture
description: >-
Learn about the internal architecture of Fiber, including the overall structure, request handling flow, routing, and path parsing.
sidebar_position: 3
---
## Overall Architecture
At the heart of Fiber is the **App** struct. It is responsible for configuring the server, managing a pool of Contexts (either our default implementation, **DefaultCtx**, or a user‑supplied **CustomCtx**), and holding the router stack with all registered routes and groups. In addition, the App contains mount fields to support sub‑applications and hooks that allow developers to run custom code at key stages (e.g. when registering routes or starting the server).
```mermaid
flowchart TD
A[App]
B["Configuration (Config)"]
C[Context Pool]
D["DefaultCtx \/ CustomCtx"]
E[Router Stack]
F["Groups & Routes"]
G["MountFields (Sub‑Apps)"]
H[Hooks]
A --> B
A --> C
C --> D
A --> E
E --> F
A --> G
A --> H
```
### Explanation
- App: The central object that bootstraps and runs the Fiber server.
- Configuration (Config): Contains settings for body limits, timeouts, TLS options, routing behavior (e.g. case‑sensitivity, strict routing), and more.
- Context Pool: A synchronized pool from which Contexts are acquired per request. This design minimizes allocations by recycling DefaultCtx (or CustomCtx) instances.
- Router Stack: Organizes all registered routes. It is later processed into a tree structure for fast route‑matching.
- MountFields: Support for mounting sub‑applications so that large APIs can be segmented into independent routers.
- Hooks: Allow for custom behavior at critical points (e.g., on route registration, route naming, on listen, on shutdown, etc.).
## Request Processing Flow
Fiber’s request processing is designed for performance and minimal overhead. When an HTTP request is received by the underlying fasthttp server, the flow is as follows:
1. Request Arrival: The fasthttp server receives the HTTP request.
2. Context Acquisition: The App calls AcquireCtx() to fetch a Context from the pool.
3. Context Reset: The acquired Context is reset (via DefaultCtx.Reset()) with the new request’s data.
4. Request Handling: The request handler (default or custom) is invoked.
5. Route Matching: The framework uses the next() (or nextCustom()) function to traverse the pre‑built route tree and find a matching route based on the URL and HTTP method.
6. Middleware Chain Execution: The matched route’s handler chain is executed in sequence.
7. Error Handling (if required): Any errors encountered trigger the registered error handler.
8. Response Generation: The response is sent back to the client.
9. Context Release: Finally, the Context is cleaned up and returned to the pool.
```mermaid
flowchart LR
R["HTTP Request (fasthttp)"]
A["App.RequestHandler (default or custom)"]
C["Acquire Context (from Pool)"]
X["Reset Context (DefaultCtx.Reset())"]
N["Route Matching (next() \/ nextCustom())"]
M["Handler Chain Execution"]
EH["Error Handling (if needed)"]
S["HTTP Response"]
RC["Release Context (to Pool)"]
R --> A
A --> C
C --> X
X --> N
N --> M
M --> EH
EH --> S
S --> RC
```
### Additional Note
Fiber minimizes memory allocations by reusing Context objects and uses an optimized route‑matching algorithm to rapidly determine the correct handler chain.
## Routing & Path Parsing
Fiber allows you to register routes using helper methods (e.g. Get(), Post()) or by creating groups and sub‑routers. Internally, the route pattern is parsed by the parseRoute() function. This function decomposes the route string into segments:
- Constant Segments: Fixed parts of the path (e.g. /api).
- Parameter Segments: Dynamic parts that begin with a colon. For example, a route may be defined as:
/api/\:userId<int>
Here, the segment \:userId<int> is a parameter segment with a type constraint (an integer).
- Constraints: Constraints (such as int, bool, datetime, or even regular expressions) are extracted from the parameter part and stored in the route’s metadata for validation at runtime.
```mermaid
flowchart TD
P["Route Pattern String (e.g., '/api/\\:userId\\<int>')"]
PA["parseRoute()"]
RP[routeParser]
RS["routeSegment(s)"]
C["Constraints (e.g., int, datetime, regex)"]
PARAM[Extracted Parameter Names]
P --> PA
PA --> RP
RP --> RS
RS --> C
RP --> PARAM
```
### Explanation
- parseRoute(): Takes a route string and returns a routeParser struct that includes a list of routeSegment objects.
- routeSegment: Represents a portion of the route. If it is a parameter segment, it may include constraints that determine the allowed format (for example, ensuring that a parameter is an integer).
- Extracted Parameter Names: These are later used to populate the request’s Context with the actual values parsed from the URL.
## Route Matching and Parameter Extraction
When a request is processed, Fiber uses its pre‑computed route tree (the treeStack) to efficiently match the incoming URL against registered routes.
1. Normalization: The URL is normalized to create a “detection path”: it is converted to lowercase unless case-sensitive routing is enabled, and trailing slashes are trimmed unless strict routing is enabled.
2. Tree Traversal: The route tree, grouped by common prefixes, is traversed based on the HTTP method.
3. Matching: Constant segments are compared exactly, while parameter segments extract dynamic values.
4. Constraint Validation: Extracted parameter values are validated against any defined constraints.
```mermaid
flowchart TD
A["Incoming Request URL (e.g., '/api/john')"]
B["Normalize URL (lowercase, trim trailing slashes)"]
C["Detection Path"]
D["Traverse Route Tree (treeStack based on method)"]
E["Match Constant Segments"]
F["Identify Parameter Segments (e.g., ':userId')"]
G["Extract Parameter Values"]
H["Validate Constraints (e.g., 'int', 'datetime', 'regex')"]
I["Route Found"]
A --> B
B --> C
C --> D
D --> E
E --> F
F --> G
G --> H
H --> I
```
### Insight
This efficient matching mechanism leverages pre‑grouped routes to minimize comparisons, while dynamic segments allow for flexible URL structures and runtime validation.
## Middleware Chain Execution
Once a matching route is found, Fiber executes the chain of middleware and route handlers sequentially. The process is as follows:
1. Initial Handler Execution: The first handler of the matched route is invoked.
2. Calling Next(): Each handler calls Ctx.Next() to pass control to the next handler in the chain.
3. Termination: When no further handlers remain, the chain terminates and the response is sent.
```mermaid
flowchart TD
A[Matched Route]
B[Handler 1]
C[Handler 2]
D[Handler 3]
E[Response Generation]
A --> B
B -- "Calls C via Next()" --> C
C -- "Calls D via Next()" --> D
D -- "No Next() available" --> E
```
### Explanation
- Each handler in the chain can perform operations (e.g. authentication, logging, transformation) before calling Next() to forward control.
- This sequential processing ensures that middleware are executed in the order they were registered.
- If an error occurs or a handler does not call Next(), the chain may be terminated early, and an error handler may be invoked.
### Observations
Middleware are executed in the order they are registered. This sequential design allows each handler to perform tasks such as authentication, logging, or transformation before delegating to the next handler.
## Sub-Application Mounting & Grouping
Fiber allows mounting sub‑applications (or sub‑routers) under specific path prefixes. This enables modular design of large APIs. The mounting process works as follows:
1. Defining a Mount Point: A parent application (or group) calls `Use` with a sub-app, which triggers the internal mount path logic.
2. Merging Mount Fields: The sub‑app’s mount fields are updated with the prefix of the parent, and its routes are integrated into the parent’s routing structure.
3. Processing Sub‑App Routes: During startup, the parent app collects routes from mounted sub‑apps and builds a unified route tree.
```mermaid
flowchart TD
A[Parent App]
B["Sub-App (Mounted)"]
C["Define Mount Point (e.g. \'/admin\')"]
D["Update MountFields (assign mount path)"]
E["Merge Sub-App Routes (append to Router Stack)"]
F[Generate Unified Route Tree]
A --> C
C --> B
B --> D
D --> E
E --> F
```
### Impact
This mechanism enables large APIs to be broken down into smaller, maintainable modules while still benefiting from Fiber’s optimized routing and request handling.
## Route Tree Building
Fiber builds a route tree (the treeStack) to optimize route matching. This involves grouping routes based on a prefix (usually the first few characters) to reduce the number of comparisons during a request.
1. Iterating Over the Router Stack: Each registered route is examined.
2. Computing the Tree Key: A key is computed from the route’s normalized path (e.g. the first 3 characters).
3. Grouping Routes: Routes are added to the appropriate branch of the tree.
4. Sorting: Within each group, routes are sorted based on their registration order (or position) to ensure the correct match is found.
```mermaid
flowchart TD
A["Router Stack (All Registered Routes)"]
B["Compute Tree Key (e.g. first 3 characters)"]
C["Group Routes by Key (treeStack)"]
D["Merge Global Routes (key \'\' for global matches)"]
E[Sort Routes within Groups]
F[Optimized Route Tree]
A --> B
B --> C
C --> D
D --> E
E --> F
```
### Explanation
- Building a route tree is an optimization step that reduces the matching overhead by limiting the search space to a subset of routes that share a common prefix.
- The tree is rebuilt whenever new routes are registered, ensuring that the latest routing configuration is always used for matching.
## Context Lifecycle Management
Fiber minimizes allocations by pooling Context objects. The lifecycle of a Context is as follows:
1. **Acquisition:** When a new HTTP request arrives, a Context is retrieved from the pool via `App.AcquireCtx()`.
2. **Reset:** The acquired Context is reset with the current `fasthttp.RequestCtx` to clear previous data and initialize new request‑specific values.
3. **Processing:** The Context is passed along the middleware and handler chain.
4. **Release:** After processing the request (or when an error occurs), the Context is released back to the pool via `App.ReleaseCtx()`, making it available for reuse.
```mermaid
flowchart TD
A["HTTP Request (fasthttp)"]
B["Acquire Context (App.AcquireCtx())"]
C["Reset Context (DefaultCtx.Reset())"]
D["Process Request (Handlers & Middleware)"]
E["Error Handling (if needed)"]
F["Release Context (App.ReleaseCtx())"]
A --> B
B --> C
C --> D
D --> E
E --> F
```
### Key Benefit
Reusing Context objects significantly reduces garbage collection overhead, ensuring Fiber remains fast and memory‑efficient even under heavy load.
## Preforking Mechanism
To take full advantage of multi‑core systems, Fiber offers a prefork mode. In this mode, the master process spawns several child processes that listen on the same port using OS features such as SO_REUSEPORT (or fall back to SO_REUSEADDR).
```mermaid
flowchart LR
M["Master Process (App)"]
C[Child Processes]
GOMAX["Set GOMAXPROCS(1)"]
REQ[Handle HTTP Requests]
WM["watchMaster()"]
M -->|Spawns| C
C --> GOMAX
C -->|Processes| REQ
C --> WM
```
### Explanation
- Master Process: The main process determines the number of available CPU cores and spawns that many child processes.
- Child Processes: Each child sets GOMAXPROCS(1) to run on a single CPU core and listens on the shared port.
- watchMaster(): Each child process runs a watchdog routine to monitor the master process; if the master exits (or its parent process ID becomes 1 on Unix‑like systems), the child terminates gracefully.
### Detailed Preforking Workflow
Fiber’s prefork mode uses OS‑level mechanisms to allow multiple processes to listen on the same port. Here’s a more detailed look:
1. Master Process Spawning: The master process detects the number of CPU cores and spawns that many child processes.
2. Child Process Initialization: Each child process sets GOMAXPROCS(1) so that it runs on a single core.
3. Binding to Port: Child processes use packages like reuseport to bind to the same address and port.
4. Parent Monitoring: Each child runs a watchdog function (watchMaster()) to monitor the master process; if the master terminates, children exit.
5. Request Handling: Each child independently handles incoming HTTP requests.
```mermaid
flowchart TD
A[Master Process]
B[Determine CPU Cores]
C[Spawn Child Processes]
D["Child Process Initialization (GOMAXPROCS(1))"]
E["Bind to Port (reuseport)"]
F["Run watchMaster() (Monitor Parent)"]
G[Handle HTTP Requests]
A --> B
B --> C
C --> D
D --> E
E --> F
F --> G
```
#### Explanation
- Preforking improves performance by allowing multiple processes to handle requests concurrently.
- Using reuseport (or a fallback) ensures that all child processes can listen on the same port without conflicts.
- The watchdog routine in each child ensures that they exit if the master process is no longer running, maintaining process integrity.
## Redirection & Flash Messages
Fiber’s redirection mechanism is implemented via the Redirect struct. This structure allows not only setting a new location for redirection but also passing along flash messages and old input data via a special cookie.
```mermaid
flowchart LR
R[Redirect Struct]
RP[redirectPool]
FM["Flash Messages \/ Old Inputs"]
M["Methods: To(), Route(), Back()"]
LH[Set Location Header]
CK["Flash Cookie (fiber\_flash)"]
R -->|Acquired from| RP
R --> FM
R --> M
M --> LH
FM -->|Serialized| CK
```
### Explanation
- Redirect Struct: Retrieved from a pool (to minimize allocations), it stores redirection settings such as the HTTP status code (defaulting to 303 See Other) and any flash messages.
- Flash Messages & Old Inputs: These are collected via methods like With() or WithInput() and then serialized and stored in a cookie named fiber_flash.
- Redirection Methods: The To(), Route(), and Back() methods determine the target URL and set the Location header accordingly.
### Flash Message Handling in Redirection
When performing redirections, Fiber can send flash messages or preserve old input data. This process involves:
1. Collecting Flash Data: When a redirect is initiated, developers can add flash messages via Redirect.With() or old input data via Redirect.WithInput().
2. Serialization: The flash messages and input data are serialized (using a fast marshalling method) into a byte sequence.
3. Setting a Cookie: The serialized data is stored in a special cookie (named fiber_flash) that will be sent to the client.
4. Retrieval & Clearing: On the subsequent request, the flash data is read from the cookie, deserialized, and then cleared.
```mermaid
flowchart TD
A[Initiate Redirect]
B["Add Flash Messages (With(), WithInput())"]
C[Serialize Flash Data]
D["Set Flash Cookie (\'fiber\_flash\')"]
E[Client Receives Redirect]
F[Next Request Reads Flash Cookie]
G["Deserialize & Clear Flash Data"]
A --> B
B --> C
C --> D
D --> E
E --> F
F --> G
```
#### Explanation
- Flash messages provide a way to pass transient data (such as notifications or error messages) to the next request after a redirect.
- The data is stored temporarily in a cookie, which is then read and cleared upon processing the next request.
- This mechanism is essential for implementing post‑redirect‑get patterns and ensuring a smooth user experience.
## Hooks, Error Handling & Context Lifecycle
### Hooks
Fiber provides a comprehensive hook system that allows you to run custom functions at key moments:
- OnRoute: Called when a route is registered.
- OnName: Invoked when a route is assigned a name.
- OnGroup: Triggered when a group is created.
- OnListen: Runs when the server starts listening.
- OnShutdown: Called during graceful shutdown.
- OnFork: Invoked when a child process is forked.
- OnMount: Used when a sub‑application is mounted.
```mermaid
flowchart TD
H[Hooks]
OR[OnRoute]
ON[OnName]
OG[OnGroup]
OL[OnListen]
OS[OnShutdown]
OF[OnFork]
OM[OnMount]
H --> OR
H --> ON
H --> OG
H --> OL
H --> OS
H --> OF
H --> OM
```
#### Explanation
- Hooks provide extension points for developers and maintainers to inject custom logic without modifying the core Fiber code.
- They are executed at various stages (for example, every time a new route is registered, the OnRoute hooks are executed to allow for logging, validation, or transformation of the route).
### Error Handling & Context Lifecycle
Fiber’s DefaultCtx (or CustomCtx) represents the per‑request context. The lifecycle is as follows:
- Acquire: A Context is obtained from the pool at the beginning of a request.
- Processing: The context is passed along to the route handlers and middleware.
- Error Handling: If an error occurs (e.g., route not found, method not allowed, or a handler panic caught by the recover middleware), Fiber calls the registered error handler. Errors such as ErrMethodNotAllowed or ErrNotFound are generated as needed.
- Release: Once the request is processed, the Context is released back into the pool for reuse.
```mermaid
flowchart LR
AC["Acquire Context (from Pool)"]
HP["Handle Request (Handlers & Middleware)"]
EH["Error Handling (if needed)"]
RC["Release Context (to Pool)"]
AC --> HP
HP --> EH
EH --> RC
```
#### Explanation
- This lifecycle ensures that Fiber minimizes allocations by reusing Context objects.
- Errors are propagated and handled consistently, and the context is properly reset after every request.
================================================
FILE: docs/extra/learning-resources.md
================================================
---
id: learning-resources
title: 📚 Learning Resources
description: >-
Interactive learning platforms and community resources to help you learn Fiber concepts through hands-on practice.
sidebar_position: 3
---
## Interactive Learning Platforms
Looking to practice Fiber concepts through hands-on exercises? Here are some community-driven learning resources:
### Go Interview Practice - Fiber Challenges
A comprehensive platform offering progressive Fiber challenges that complement the official documentation.

**What You'll Learn:**
- **High-Performance APIs** - Build ultra-fast RESTful APIs with zero-allocation routing
- **Middleware & Security** - Implement custom middleware, rate limiting, CORS, and authentication
- **Request Validation** - Input validation, error handling, and data transformation
- **Authentication & JWT** - Secure authentication systems with JWT tokens and API key validation

**Challenge Roadmap:**
1. **Basic Routing** - Setup Fiber, routes, and handlers (Beginner)
2. **Middleware & CORS** - Custom middleware and rate limiting (Intermediate)
3. **Validation & Errors** - Input validation and error handling (Intermediate)
4. **Authentication** - JWT tokens and API key validation (Advanced)


[Explore Fiber Challenges →](https://rezasi.github.io/go-interview-practice/fiber) | [GitHub Repository →](https://github.com/RezaSi/go-interview-practice)
================================================
FILE: docs/guide/_category_.json
================================================
{
"label": "\uD83D\uDCD6 Guide",
"position": 7,
"link": {
"type": "generated-index",
"description": "Guides for Fiber."
}
}
================================================
FILE: docs/guide/advance-format.md
================================================
---
id: advance-format
title: 🐛 Advanced Format
description: >-
Learn how to use MessagePack (MsgPack) and CBOR for efficient binary serialization in Fiber applications.
sidebar_position: 9
---
## MsgPack
Fiber lets you use MessagePack for efficient binary serialization. Use one of the popular Go libraries below to encode and decode data in handlers.
- Fiber can bind requests with the `application/vnd.msgpack` content type out of the box. See the [Binding documentation](../api/bind.md#msgpack) for details.
- Use `Bind().MsgPack()` to bind data to structs, similar to JSON. `Ctx.AutoFormat()` responds with MsgPack when the `Accept` header is `application/vnd.msgpack`. See the [AutoFormat documentation](../api/ctx.md#autoformat) for more.
### Recommended Libraries
- [github.com/vmihailenco/msgpack](https://pkg.go.dev/github.com/vmihailenco/msgpack) — A widely used, feature-rich MsgPack library.
- [github.com/shamaton/msgpack/v3](https://pkg.go.dev/github.com/shamaton/msgpack/v3) — High-performance MsgPack library.
### Installation
Install either library using:
```bash
go get github.com/vmihailenco/msgpack
# or
go get github.com/shamaton/msgpack/v3
```
> **Note:** Fiber doesn't bundle a MsgPack implementation because it's outside the Go standard library. Pick one of the popular libraries in the ecosystem; the two listed above are widely used and well maintained.
### Example: Using `shamaton/msgpack/v3`
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/shamaton/msgpack/v3"
)
type User struct {
Name string `msgpack:"name"` // tag may vary depending on your MsgPack library
Age int `msgpack:"age"`
}
func main() {
app := fiber.New(fiber.Config{
// Optional: Set custom MsgPack encoder/decoder
MsgPackEncoder: msgpack.Marshal,
MsgPackDecoder: msgpack.Unmarshal,
})
app.Post("/msgpack", func(c fiber.Ctx) error {
var user User
if err := c.Bind().MsgPack(&user); err != nil {
return err
}
// Content type will be set automatically to application/vnd.msgpack
return c.MsgPack(user)
})
app.Listen(":3000")
}
```
## CBOR
Fiber doesn't ship with a CBOR implementation. Use a library such as [fxamacker/cbor](https://github.com/fxamacker/cbor) to add encoding and decoding.
- Use `Bind().CBOR()` to bind CBOR to structs. `Ctx.AutoFormat()` replies with CBOR when the `Accept` header is `application/cbor`. See the [AutoFormat documentation](../api/ctx.md#autoformat) for details.
```bash
go get github.com/fxamacker/cbor/v2
```
Configure Fiber with the chosen library:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/fxamacker/cbor/v2"
)
func main() {
app := fiber.New(fiber.Config{
CBOREncoder: cbor.Marshal,
CBORDecoder: cbor.Unmarshal,
})
type User struct {
Name string `cbor:"name"`
Age int `cbor:"age"`
}
app.Post("/cbor", func(c fiber.Ctx) error {
var user User
if err := c.Bind().CBOR(&user); err != nil {
return err
}
// Content type will be set automatically to application/cbor
return c.CBOR(user)
})
app.Listen(":3000")
}
```
================================================
FILE: docs/guide/context.md
================================================
---
id: go-context
title: "\U0001F9E0 Go Context"
description: >-
Learn how Fiber's Ctx integrates with Go's context.Context,
how to interact with the underlying fasthttp RequestCtx,
and how to use the available context helpers.
sidebar_position: 6
toc_max_heading_level: 4
---
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
## Fiber Context as `context.Context`
Fiber's [`Ctx`](../api/ctx.md) implements Go's
[`context.Context`](https://pkg.go.dev/context#Context) interface.
You can pass `c` directly to functions that expect a `context.Context`
without adapters.
However, `Ctx` doesn't propagate cancellation yet, so its
`Deadline`, `Done`, and `Err` methods are no-ops. (The underlying
`fasthttp.RequestCtx` behaves differently; see below.)
:::caution
The `fiber.Ctx` instance is only valid within the lifetime of the handler.
It is reused for subsequent requests, so avoid storing `c` or using it in
goroutines that outlive the handler. For asynchronous work, call
`c.Context()` inside the handler to obtain a `context.Context` that can safely
be used after the handler returns. By default, this returns `context.Background()`
unless a custom context was provided with `c.SetContext`.
:::
```go title="Example"
func doSomething(ctx context.Context) {
// ... your logic here
}
app.Get("/", func(c fiber.Ctx) error {
doSomething(c) // c satisfies context.Context
return nil
})
```
### Using context outside the handler
`fiber.Ctx` is recycled after each request. If you need a context that lives
longer—for example, for work performed in a new goroutine—obtain it with
`c.Context()` before returning from the handler.
```go title="Async work"
app.Get("/job", func(c fiber.Ctx) error {
ctx := c.Context()
go performAsync(ctx)
return c.SendStatus(fiber.StatusAccepted)
})
```
You can customize the base context by calling `c.SetContext` before
requesting it:
```go
app.Get("/job", func(c fiber.Ctx) error {
c.SetContext(context.WithValue(context.Background(), "requestID", "123"))
ctx := c.Context()
go performAsync(ctx)
return nil
})
```
### Retrieving Values
`Ctx.Value` is backed by [Locals](../api/ctx.md#locals).
Values stored with `c.Locals` are accessible through `c.Value`, so code that
receives `c` as a `context.Context` can read them with `ctx.Value(key)`.
```go title="Locals and Value"
app.Get("/", func(c fiber.Ctx) error {
c.Locals("role", "admin")
role := c.Value("role") // returns "admin"
return c.SendString(role.(string))
})
```
## Working with `RequestCtx` and `fasthttpctx`
The underlying [`fasthttp.RequestCtx`](https://pkg.go.dev/github.com/valyala/fasthttp#RequestCtx)
can be accessed via `c.RequestCtx()`.
This exposes low-level APIs and the extra context support provided by
`fasthttpctx`.
```go title="Accessing RequestCtx"
app.Get("/raw", func(c fiber.Ctx) error {
fctx := c.RequestCtx()
// use fasthttp APIs directly
fctx.Response.Header.Set("X-Engine", "fasthttp")
return nil
})
```
`fasthttpctx` enables `fasthttp` to satisfy the `context.Context` interface.
`Deadline` always reports no deadline, `Done` is closed when the client
connection ends, and once it fires `Err` reports `context.Canceled`. This
means handlers can detect client disconnects while still passing
`c.RequestCtx()` into APIs that expect a `context.Context`.
## Context Helpers
Fiber and its middleware expose a number of helper functions that
retrieve request-scoped values from the context.
### Request ID
The RequestID middleware stores the generated identifier in the context.
Use `requestid.FromContext` to read it later.
```go
app.Use(requestid.New())
app.Get("/", func(c fiber.Ctx) error {
id := requestid.FromContext(c)
return c.SendString(id)
})
```
### CSRF
The CSRF middleware provides helpers to fetch the token or the handler
attached to the current context.
```go
app.Use(csrf.New())
app.Get("/form", func(c fiber.Ctx) error {
token := csrf.TokenFromContext(c)
return c.SendString(token)
})
```
```go title="Deleting a token"
app.Post("/logout", func(c fiber.Ctx) error {
handler := csrf.HandlerFromContext(c)
if handler != nil {
// Invalidate the token on logout
_ = handler.DeleteToken(c)
}
// ... other logout logic
return c.SendString("Logged out")
})
```
### Session
Sessions are stored on the context and can be retrieved via
`session.FromContext`.
```go
app.Use(session.New())
app.Get("/", func(c fiber.Ctx) error {
sess := session.FromContext(c)
count := sess.Get("visits")
return c.JSON(fiber.Map{"visits": count})
})
```
### Basic Authentication
After successful authentication, the username is available with
`basicauth.UsernameFromContext`. Passwords in `Users` must be pre-hashed.
```go
app.Use(basicauth.New(basicauth.Config{
Users: map[string]string{
// "secret" hashed using SHA-256
"admin": "{SHA256}K7gNU3sdo+OL0wNhqoVWhr3g6s1xYv72ol/pe/Unols=",
},
}))
app.Get("/", func(c fiber.Ctx) error {
user := basicauth.UsernameFromContext(c)
return c.SendString(user)
})
```
### Key Authentication
For API key authentication, the extracted token is stored in the
context and accessible via `keyauth.TokenFromContext`.
```go
app.Use(keyauth.New())
app.Get("/", func(c fiber.Ctx) error {
token := keyauth.TokenFromContext(c)
return c.SendString(token)
})
```
## Using `context.WithValue` and Friends
Since `fiber.Ctx` conforms to `context.Context`, standard helpers such as
`context.WithValue`, `context.WithTimeout`, or `context.WithCancel`
can wrap the request context when needed.
```go
app.Get("/job", func(c fiber.Ctx) error {
ctx, cancel := context.WithTimeout(c, 5*time.Second)
defer cancel()
// pass ctx to async operations that honor cancellation
if err := doWork(ctx); err != nil {
return err
}
return c.SendStatus(fiber.StatusOK)
})
```
### Context Cancellation with Goroutines in Fiber
When starting asynchronous work inside a handler, Fiber does not cancel the base `fiber.Ctx` automatically.
By wrapping the request context with `context.WithTimeout`, you can create a derived context that honors deadlines and cancellation signals.
The goroutine checks `ctx.Done()` before sending a result.
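If the deadline expires (or the parent context is canceled), the goroutine exits early and avoids leaking resources. Note that `c.Context()` returns `context.Background()` unless you call `c.SetContext`, so in this example only the timeout cancels `ctx`.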
The handler then waits for either:
- a result from the goroutine, or
- the context deadline, in which case the handler responds with 504 Gateway Timeout
This pattern ensures that long-running operations (database queries, external API calls, background tasks) do not continue running after the request has ended.
```go
func Handler(c fiber.Ctx) error {
ctx, cancel := context.WithTimeout(c.Context(), 2*time.Second)
defer cancel()
resultChan := make(chan string, 1)
go func() {
select {
case <-time.After(3 * time.Second):
select {
case <-ctx.Done():
return
case resultChan <- "done":
}
case <-ctx.Done():
return
}
}()
select {
case res := <-resultChan:
return c.SendString(res)
case <-ctx.Done():
return c.Status(fiber.StatusGatewayTimeout).SendString("timeout")
}
}
```
This approach provides safe cancellation semantics for goroutine-based work while allowing you to integrate Fiber handlers with context-aware APIs.
## Summary
- `fiber.Ctx` satisfies `context.Context` but its `Deadline`, `Done`, and `Err`
methods are currently no-ops.
- `RequestCtx` exposes the raw `fasthttp` context, whose `Done` channel closes
when the client connection ends.
- Use `fiber.StoreInContext(c, key, value)` to store request-scoped values in
  `c.Locals()`; when `PassLocalsToContext` is enabled, the value is also stored
  in `c.Context()` so it is available through either API.
- Middleware helpers like `requestid.FromContext` or `session.FromContext`
make it easy to retrieve request-scoped data.
- Standard helpers such as `context.WithTimeout` can wrap `fiber.Ctx` to create
fully featured derived contexts inside handlers.
- `fiber.Config.PassLocalsToContext` controls whether Fiber context helpers
also propagate values into the request `context.Context` for Fiber-backed
contexts when using `StoreInContext`. It defaults to `false` for backward
compatibility, while `ValueFromContext` keeps reading from `c.Locals()`.
- Use `c.Context()` to obtain a `context.Context` that can outlive the handler,
and `c.SetContext()` to customize it with additional values or deadlines.
With these tools, you can seamlessly integrate Fiber applications with
Go's context-based APIs and manage request-scoped data effectively.
================================================
FILE: docs/guide/error-handling.md
================================================
---
id: error-handling
title: 🐛 Error Handling
description: >-
Fiber supports centralized error handling: handlers return errors so you can
log them or send a custom HTTP response to the client.
sidebar_position: 4
---
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
## Catching Errors
Return errors from route handlers and middleware so Fiber can handle them centrally.
```go
app.Get("/", func(c fiber.Ctx) error {
// Pass error to Fiber
return c.SendFile("file-does-not-exist")
})
```
Fiber does not recover from [panics](https://go.dev/blog/defer-panic-and-recover) by default. Add the `Recover` middleware to catch panics in any handler:
```go title="Example"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/recover"
)
func main() {
app := fiber.New()
app.Use(recover.New())
app.Get("/", func(c fiber.Ctx) error {
panic("This panic is caught by fiber")
})
log.Fatal(app.Listen(":3000"))
}
```
Use `fiber.NewError()` to create an error with a status code. If you omit the message, Fiber uses the standard status text (for example, `404` becomes `Not Found`).
```go title="Example"
app.Get("/", func(c fiber.Ctx) error {
// 503 Service Unavailable
return fiber.ErrServiceUnavailable
// 503 On vacation!
return fiber.NewError(fiber.StatusServiceUnavailable, "On vacation!")
})
```
## Default Error Handler
Fiber ships with a default error handler that sends **500 Internal Server Error** for generic errors. If the error is a [`*fiber.Error`](https://pkg.go.dev/github.com/gofiber/fiber/v3#Error), the response uses its status code and message.
```go title="Example"
// Default error handler
var DefaultErrorHandler = func(c fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Set Content-Type: text/plain; charset=utf-8
c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
// Return status code with error message
return c.Status(code).SendString(err.Error())
}
```
## Custom Error Handler
Set a custom error handler in [`fiber.Config`](../api/fiber.md#errorhandler) when creating a new app.
The default handler covers most cases, but a custom handler lets you react to specific error types—for example, by logging to a service or sending a tailored JSON or HTML response.
The following example shows how to display error pages for different types of errors.
```go title="Example"
// Create a new fiber instance with custom config
app := fiber.New(fiber.Config{
// Override default error handler
ErrorHandler: func(ctx fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
err = ctx.Status(code).SendFile(fmt.Sprintf("./%d.html", code))
if err != nil {
// In case the SendFile fails
return ctx.Status(fiber.StatusInternalServerError).SendString("Internal Server Error")
}
// Return from handler
return nil
},
})
// ...
```
> Special thanks to the [Echo](https://echo.labstack.com/) and [Express](https://expressjs.com/) frameworks for inspiring parts of this error-handling approach.
================================================
FILE: docs/guide/extractors.md
================================================
---
id: extractors
title: 🔬 Extractors
description: Learn how to use extractors in Fiber middleware
sidebar_position: 8.5
toc_max_heading_level: 4
---
The extractors package provides shared value extraction utilities for Fiber middleware packages. It helps reduce code duplication across middleware packages while ensuring consistent behavior and security practices.
## Overview
The `github.com/gofiber/fiber/v3/extractors` module provides standardized value extraction utilities integrated into Fiber's middleware ecosystem. This approach:
- **Reduces Code Duplication**: Eliminates redundant extractor implementations across middleware packages
- **Ensures Consistency**: Maintains identical behavior and security practices across all extractors
- **Simplifies Maintenance**: Changes to extraction logic only need to be made in one place
- **Enables Direct Usage**: Middleware can import and use extractors directly
- **Improves Performance**: Shared, optimized extraction functions reduce overhead
## What Are Extractors?
Extractors are utilities that middleware uses to get values from different parts of HTTP requests.
### Available Extractors
- `FromAuthHeader(authScheme string)`: Extract from Authorization header with optional scheme
- `FromCookie(key string)`: Extract from HTTP cookies
- `FromParam(param string)`: Extract from URL path parameters
- `FromForm(param string)`: Extract from form data
- `FromHeader(header string)`: Extract from custom HTTP headers
- `FromQuery(param string)`: Extract from URL query parameters
- `FromCustom(key string, fn func(fiber.Ctx) (string, error))`: Define custom extraction logic with metadata
- `Chain(extractors ...Extractor)`: Chain multiple extractors with fallback logic
### Extractor Structure
Each `Extractor` contains:
```go
type Extractor struct {
Extract func(fiber.Ctx) (string, error) // Extraction function
Key string // Parameter/header name
Source Source // Source type for inspection
AuthScheme string // Auth scheme (FromAuthHeader)
Chain []Extractor // Chained extractors
}
```
Extractors can read values from these parts of a request:
- **Headers**: `Authorization`, `X-API-Key`, custom headers
- **Cookies**: Session cookies, authentication tokens
- **Query Parameters**: URL parameters like `?token=abc123`
- **Form Data**: POST body form fields
- **URL Parameters**: Route parameters like `/users/:id`
### Chain Behavior
The `Chain` function creates extractors that try multiple sources in order:
- Returns the first successful extraction (non-empty value with no error)
- If all extractors fail, returns the last error encountered or `ErrNotFound`
- **Robust error handling**: Skips extractors with `nil` Extract functions
- Preserves the source and key from the first extractor for metadata
- Stores a defensive copy of all chained extractors for introspection via the `Chain` field
## Why Middleware Uses Extractors
Middleware needs to extract values from requests for authentication, authorization, and other purposes. Extractors provide:
- **Security Awareness**: Different sources have different security implications
- **Fallback Support**: Try multiple sources if the first one doesn't have the value
- **Consistency**: Same extraction logic across all middleware packages
- **Source Tracking**: Know where values came from for security decisions
## Usage Examples
### Basic Usage
```go
// KeyAuth middleware extracts key from header
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.FromHeader("Middleware-Key"),
}))
```
### Fallback Chains
```go
// Try multiple sources in order
tokenExtractor := extractors.Chain(
extractors.FromHeader("Middleware-Key"), // Try header first
extractors.FromCookie("middleware_key"), // Then cookie
extractors.FromQuery("middleware_key"), // Finally query param
)
app.Use(keyauth.New(keyauth.Config{
Extractor: tokenExtractor,
}))
```
## Configuring Middleware That Uses Extractors
### Authentication Middleware
```go
// KeyAuth middleware (default: FromAuthHeader)
app.Use(keyauth.New(keyauth.Config{
// Default extracts from Authorization header
// Extractor: extractors.FromAuthHeader("Bearer"),
}))
// Custom header extraction
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.FromHeader("X-API-Key"),
}))
// Multiple sources with secure fallback
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.Chain(
extractors.FromAuthHeader("Bearer"), // Secure first
extractors.FromHeader("X-API-Key"), // Then custom header
extractors.FromQuery("api_key"), // Least secure last
),
}))
```
### Session Middleware
```go
// Session middleware (default: FromCookie)
app.Use(session.New(session.Config{
// Default extracts from session_id cookie
// Extractor: extractors.FromCookie("session_id"),
}))
// Custom cookie name
app.Use(session.New(session.Config{
Extractor: extractors.FromCookie("my_session"),
}))
```
### CSRF Middleware
```go
// CSRF middleware (default: FromHeader)
app.Use(csrf.New(csrf.Config{
// Default extracts from X-CSRF-Token header
// Extractor: extractors.FromHeader("X-CSRF-Token"),
}))
// Form-based CSRF (less secure, use only if needed)
app.Use(csrf.New(csrf.Config{
Extractor: extractors.Chain(
extractors.FromHeader("X-CSRF-Token"), // Secure first
extractors.FromForm("_csrf"), // Form fallback
),
}))
```
## Security Considerations
### Source Characteristics
Different extraction sources have different security properties and use cases:
#### Headers (Generally Preferred)
- **Authorization Header**: Standard for authentication tokens, widely supported
- **Custom Headers**: Application-specific, less likely to be logged by default
- **Considerations**: Can be intercepted without HTTPS, may be stripped by proxies
#### Cookies (Good for Sessions)
- **Session Cookies**: Designed for secure client-side storage
- **Considerations**: Require proper `Secure`, `HttpOnly`, and `SameSite` flags
- **Best for**: Session management, remember-me tokens
#### Query Parameters (Use Sparingly)
- **Query parameters**: Convenient for simple APIs and debugging
- **Considerations**: Always visible in URLs, logged by servers/proxies, stored in browser history
- **Best for**: Non-sensitive parameters, public identifiers
#### Form Data (Context Dependent)
- **POST Bodies**: Suitable for form submissions and API requests
- **Considerations**: Avoid putting sensitive data in query strings; ensure request bodies aren’t logged and use the correct content type
- **Best for**: User-generated content, file uploads
### Security Best Practices
1. **Use HTTPS**: Encrypt all traffic to protect extracted values in transit
2. **Validate Input**: Always validate and sanitize extracted values
3. **Log Carefully**: Avoid logging sensitive values from any source
4. **Choose Appropriate Sources**: Match the source to your security requirements
5. **Test Thoroughly**: Verify extraction works in your environment
6. **Monitor Security**: Watch for extraction failures or unusual patterns
### Chain Ordering Strategy
When using multiple sources, order them by your security preferences:
```go
// Example: Prefer headers, fall back to cookies, then query
extractors.Chain(
extractors.FromAuthHeader("Bearer"), // Standard auth
extractors.FromCookie("auth_token"), // Secure storage
extractors.FromQuery("token"), // Public fallback
)
```
The "best" source depends on your specific use case, security requirements, and application architecture.
### Common Security Issues
#### Leaky URLs
```go
// ❌ DON'T: API keys in URLs (visible in logs, history, bookmarks)
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.FromQuery("api_key"), // PROBLEMATIC
}))
// ✅ DO: API keys in headers (not visible in URLs)
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.FromHeader("X-API-Key"), // BETTER
}))
```
#### Session Tokens in Query Parameters
```go
// ❌ DON'T: Session tokens in URLs (can be bookmarked, leaked)
app.Use(session.New(session.Config{
Extractor: extractors.FromQuery("session"), // PROBLEMATIC
}))
// ✅ DO: Session tokens in cookies (designed for this purpose)
app.Use(session.New(session.Config{
Extractor: extractors.FromCookie("session_id"), // BETTER
}))
```
#### Form-Only CSRF Tokens
While the default extractor uses headers, some implementations use form fields, which is fine if you don't have AJAX or API clients:
```go
// ❌ DON'T: CSRF tokens only in forms (breaks AJAX, API calls)
app.Use(csrf.New(csrf.Config{
Extractor: extractors.FromForm("_csrf"), // LIMITED
}))
// ✅ DO: Header-first with form fallback (works everywhere)
app.Use(csrf.New(csrf.Config{
Extractor: extractors.Chain(
extractors.FromHeader("X-CSRF-Token"), // PREFERRED
extractors.FromForm("_csrf"), // FALLBACK
),
}))
```
### Understanding Trade-offs
**No extractor is universally "secure" - security depends on:**
- Whether you're using HTTPS
- How you configure cookies (Secure, HttpOnly, SameSite flags)
- Your logging and monitoring setup
- The sensitivity of the data being extracted
- Your threat model and security requirements
Choose extractors based on your specific use case and security needs, not blanket "secure" vs "insecure" labels.
## Standards Compliance
### Authorization Header (RFC 9110 & RFC 7235)
The `FromAuthHeader` extractor provides comprehensive RFC compliance with strict security validation:
#### RFC 9110 Compliance (Authorization Header Format)
- **Section 11.6.2 Format**: Enforces `credentials = auth-scheme 1*SP token68` structure
- **Space separator**: Requires a single space between the auth-scheme and the token; multiple spaces are rejected (see the examples below)
- **Case-insensitive scheme matching**: `Bearer`, `bearer`, `BEARER` all work correctly
- **Proper whitespace handling**: Rejects tabs between scheme and token (only spaces allowed)
#### RFC 7235 Token68 Validation
The extractor implements strict token68 character validation per RFC 7235:
- **Allowed characters**: `A-Z`, `a-z`, `0-9`, `-`, `.`, `_`, `~`, `+`, `/`, `=`
- **Padding rules**: `=` characters only allowed at the end of tokens
- **Security validation**: Prevents tokens starting with `=` or having non-padding characters after `=`
- **Whitespace rejection**: Rejects tokens containing spaces, tabs, or any other whitespace
#### Security Features
- **Header injection prevention**: Strict parsing prevents malformed authorization headers from bypassing authentication
- **Token validation**: Ensures extracted tokens conform to standards, preventing authentication bypass
- **Consistent error handling**: Returns `ErrNotFound` for all invalid cases
#### Examples
```go
// Standard usage - strict validation
extractor := extractors.FromAuthHeader("Bearer")
// ✅ Valid cases:
// "Bearer abc123" -> "abc123"
// "bearer ABC123" -> "ABC123" (case-insensitive scheme)
// "Bearer token123=" -> "token123=" (valid padding)
// "Bearer token==" -> "token==" (valid multiple padding)
// ❌ Invalid cases (all return ErrNotFound):
// "Bearer abc def" -> rejected (space in token)
// "Bearer abc\tdef" -> rejected (tab in token)
// "Bearer =abc" -> rejected (padding at start)
// "Bearer ab=cd" -> rejected (padding in middle)
// "Bearer token" -> rejected (multiple spaces after scheme)
// "Bearer\ttoken" -> rejected (tab after scheme)
// "Bearertoken" -> rejected (no space after scheme)
// Raw header extraction (no validation)
rawExtractor := extractors.FromAuthHeader("")
// "CustomAuth anything goes here" -> "CustomAuth anything goes here"
```
#### Benefits
- **Standards Compliance**: Full adherence to HTTP authentication RFCs
- **Security Hardening**: Prevents common authentication bypass vulnerabilities
- **Consistent Behavior**: Reliable parsing across different client implementations
- **Developer Confidence**: Clear validation rules reduce authentication bugs
## Troubleshooting
### Extraction Fails
**Problem**: Middleware returns "value not found" or authentication fails
**Solutions**:
1. Check if the expected header/cookie/query parameter is present
2. Verify the key name matches exactly (headers are case-insensitive; params/cookies/query keys are case-sensitive)
3. Ensure the request uses the correct HTTP method (GET vs POST)
4. Check if middleware is configured with the right extractor
**Debug Example**:
```go
// Add simple debug logging (avoid logging secrets in production)
app.Use(func(c fiber.Ctx) error {
hdr := c.Get("X-API-Key")
cookie := c.Cookies("session_id")
if hdr != "" || cookie != "" {
log.Printf("debug: X-API-Key present=%t, session_id present=%t", hdr != "", cookie != "")
}
return c.Next()
})
```
### Wrong Source Used
**Problem**: Values extracted from unexpected sources
**Solutions**:
1. Check middleware configuration order
2. Verify chain order (first successful extraction wins)
3. Use more specific extractors when needed
### Security Warnings
**Problem**: Getting security warnings in logs
**Solutions**:
1. Switch to more secure sources (headers/cookies)
2. Use HTTPS to encrypt traffic
3. Review if sensitive data should be in that source
## Advanced Usage
### Custom Extraction Logic
Extractors support custom extractors for complex scenarios:
```go
// Extract from custom logic (rarely needed)
customExtractor := extractors.FromCustom("my-source", func(c fiber.Ctx) (string, error) {
// Complex extraction logic
if value := c.Locals("computed_token"); value != nil {
return value.(string), nil
}
return "", extractors.ErrNotFound
})
```
:::warning
**Custom extractors break source awareness.** When you use `FromCustom`, middleware cannot determine where the value came from, which means:
- **No automatic security warnings** for potentially insecure sources
- **No source-based logging** or monitoring capabilities
- **Developer responsibility** for ensuring the extraction is secure and appropriate
**Only use `FromCustom` when:**
- Standard extractors don't meet your needs
- You've carefully evaluated the security implications
- You're confident in the security of your custom extraction logic
- You understand that middleware cannot provide source-aware security guidance
**Note:** If you pass `nil` as the function parameter, `FromCustom` will return an extractor that always fails with `ErrNotFound`.
:::
### Multiple Middleware Coordination
When using multiple middleware that extract values, ensure they don't conflict:
```go
// Good: Different sources for different purposes
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.FromHeader("X-API-Key"),
}))
app.Use(session.New(session.Config{
Extractor: extractors.FromCookie("session_id"),
}))
// Avoid: Same source for different middleware
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.FromCookie("token"), // API auth
}))
app.Use(session.New(session.Config{
Extractor: extractors.FromCookie("token"), // Session - CONFLICT!
}))
```
================================================
FILE: docs/guide/faster-fiber.md
================================================
---
id: faster-fiber
title: ⚡ Make Fiber Faster
sidebar_position: 7
---
## Custom JSON Encoder/Decoder
Fiber defaults to the standard `encoding/json` for stability and reliability. If you need more speed, consider these libraries:
- [goccy/go-json](https://github.com/goccy/go-json)
- [bytedance/sonic](https://github.com/bytedance/sonic)
- [segmentio/encoding](https://github.com/segmentio/encoding)
- [minio/simdjson-go](https://github.com/minio/simdjson-go)
```go title="Example"
package main
import "github.com/gofiber/fiber/v3"
import "github.com/goccy/go-json"
func main() {
app := fiber.New(fiber.Config{
JSONEncoder: json.Marshal,
JSONDecoder: json.Unmarshal,
})
// ...
}
```
### References
- [Set custom JSON encoder for client](../client/rest.md#setjsonmarshal)
- [Set custom JSON decoder for client](../client/rest.md#setjsonunmarshal)
- [Set custom JSON encoder for application](../api/fiber.md#jsonencoder)
- [Set custom JSON decoder for application](../api/fiber.md#jsondecoder)
================================================
FILE: docs/guide/grouping.md
================================================
---
id: grouping
title: 🎭 Grouping
sidebar_position: 2
---
:::info
Groups are virtual: each route is flattened with its group's prefix, and routes execute in declaration order, mirroring Express.js.
:::
## Paths
Groups can use path prefixes to organize related routes.
```go
func main() {
app := fiber.New()
api := app.Group("/api", middleware) // /api
v1 := api.Group("/v1", middleware) // /api/v1
v1.Get("/list", handler) // /api/v1/list
v1.Get("/user", handler) // /api/v1/user
v2 := api.Group("/v2", middleware) // /api/v2
v2.Get("/list", handler) // /api/v2/list
v2.Get("/user", handler) // /api/v2/user
log.Fatal(app.Listen(":3000"))
}
```
:::note
Group prefixes follow the same slash-boundary rule as `app.Use`. A prefix must either match the full path or stop at a `/`, so `/api` applies to `/api` and `/api/v1` but not `/apiv1`. Parameter markers (for example `:id`, `:id?`, `*`, and `+`) are processed before checking the boundary.
:::
The handler argument is optional, so groups can also be created without one.
```go
func main() {
app := fiber.New()
api := app.Group("/api") // /api
v1 := api.Group("/v1") // /api/v1
v1.Get("/list", handler) // /api/v1/list
v1.Get("/user", handler) // /api/v1/user
v2 := api.Group("/v2") // /api/v2
v2.Get("/list", handler) // /api/v2/list
v2.Get("/user", handler) // /api/v2/user
log.Fatal(app.Listen(":3000"))
}
```
:::caution
Requests to the group prefixes themselves (`/api`, `/api/v1`, or `/api/v2`) return **404** because no route is registered for them, so add handlers as needed.
:::
## Group Handlers
Group handlers run as middleware for the group's routes and must call `c.Next()` to continue the chain.
```go
func main() {
app := fiber.New()
handler := func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
}
api := app.Group("/api") // /api
v1 := api.Group("/v1", func(c fiber.Ctx) error { // middleware for /api/v1
c.Set("Version", "v1")
return c.Next()
})
v1.Get("/list", handler) // /api/v1/list
v1.Get("/user", handler) // /api/v1/user
log.Fatal(app.Listen(":3000"))
}
```
================================================
FILE: docs/guide/reverse-proxy.md
================================================
---
id: reverse-proxy
title: 🔄 Reverse Proxy Configuration
description: >-
Learn how to set up reverse proxies like Nginx or Traefik to enable modern
HTTP capabilities in your Fiber application, including HTTP/2 and
HTTP/3 (QUIC) support. This guide also covers basic reverse
proxy configuration and links to external documentation.
sidebar_position: 4
---
## Reverse Proxies
Running Fiber behind a reverse proxy is a common production setup.
Reverse proxies can handle:
- **HTTPS/TLS termination** (offloading SSL certificates)
- **Protocol upgrades** (HTTP/2, HTTP/3 support)
- **Request routing & load balancing**
- **Caching & compression**
- **Security features** (rate limiting, WAF, DDoS mitigation)
Some Fiber features (like [`SendEarlyHints`](../api/ctx.md#sendearlyhints)) require **HTTP/2 or newer**, which is easiest to enable using a reverse proxy.
### Popular Reverse Proxies
- [Nginx](https://nginx.org/)
- [Traefik](https://traefik.io/)
- [HAProxy](https://www.haproxy.com/)
- [Caddy](https://caddyserver.com/)
## Getting the Real Client IP Address
When your Fiber application is behind a reverse proxy, the TCP connection comes from the proxy server, not the actual client. To get the real client IP address, you need to configure Fiber to read it from proxy headers like `X-Forwarded-For`.
:::warning Security Warning
Proxy headers can be easily spoofed by malicious clients. **Always** configure `TrustProxyConfig` to validate the proxy IP address, otherwise attackers can forge headers to bypass IP-based access controls, rate limiting, or geolocation features.
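In addition, your reverse proxy should be configured to **set or overwrite** the forwarding header you choose (for example, `X-Forwarded-For`) based on the real client connection, or to use its real IP / PROXY protocol features. Do not simply pass through client-supplied forwarding headers, or `c.IP()` may still be controlled by an attacker even when `TrustProxyConfig` is correct. For example, Nginx's `$proxy_add_x_forwarded_for` appends the connecting address to any `X-Forwarded-For` value the client already sent; use `$remote_addr` when Nginx is the first proxy in the chain.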
:::
### Configuration
To enable reading the client IP from proxy headers, you must configure **three settings**:
1. **`TrustProxy`** - Enable proxy header trust (must be `true`)
2. **`ProxyHeader`** - Specify which header contains the client IP
3. **`TrustProxyConfig`** - Define which proxy IPs to trust
```go title="Example - App Behind Nginx"
app := fiber.New(fiber.Config{
// Enable proxy support
TrustProxy: true,
// Read client IP from X-Forwarded-For header
ProxyHeader: fiber.HeaderXForwardedFor,
// Trust requests from your Nginx proxy
TrustProxyConfig: fiber.TrustProxyConfig{
// Option 1: Trust specific proxy IPs
Proxies: []string{"10.10.0.58", "192.168.1.0/24"},
// Option 2: Or trust all private IPs (useful for internal load balancers)
// Private: true,
},
})
```
### Common Proxy Headers
Different proxies use different headers:
| Proxy/Service | Recommended Header | Config Value |
|---------------|-------------------|--------------|
| Nginx, HAProxy, Apache | X-Forwarded-For | `fiber.HeaderXForwardedFor` |
| Cloudflare | CF-Connecting-IP | `"CF-Connecting-IP"` |
| Fastly | Fastly-Client-IP | `"Fastly-Client-IP"` |
| Generic | X-Real-IP | `"X-Real-IP"` |
### TrustProxyConfig Options
The `TrustProxyConfig` struct provides multiple ways to specify trusted proxies:
```go
TrustProxyConfig: fiber.TrustProxyConfig{
// Specific IPs or CIDR ranges
Proxies: []string{
"10.10.0.58", // Single IP
"192.168.0.0/24", // CIDR range
"2001:db8::/32", // IPv6 range
},
// Or use convenience flags:
Loopback: true, // Trust 127.0.0.0/8, ::1/128
LinkLocal: true, // Trust 169.254.0.0/16, fe80::/10
Private: true, // Trust 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7
UnixSocket: true, // Trust Unix domain socket connections
},
```
### Complete Example with Nginx
```nginx title="nginx.conf"
server {
listen 443 ssl;
http2 on;
server_name example.com;
ssl_certificate /etc/ssl/certs/example.crt;
ssl_certificate_key /etc/ssl/private/example.key;
location / {
proxy_pass http://127.0.0.1:3000;
proxy_http_version 1.1;
proxy_set_header Connection "";
proxy_set_header Host $host;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
}
```
```go title="main.go"
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New(fiber.Config{
TrustProxy: true,
ProxyHeader: fiber.HeaderXForwardedFor,
EnableIPValidation: true,
TrustProxyConfig: fiber.TrustProxyConfig{
// Trust localhost since Nginx is on the same machine
Loopback: true,
},
})
app.Get("/", func(c fiber.Ctx) error {
// This will now return the real client IP from X-Forwarded-For
// instead of 127.0.0.1
return c.SendString("Your IP: " + c.IP())
})
log.Fatal(app.Listen(":3000"))
}
```
### Testing Your Configuration
You can verify your configuration is working:
```go
app.Get("/debug", func(c fiber.Ctx) error {
return c.JSON(fiber.Map{
"c.IP()": c.IP(), // Should show real client IP
"X-Forwarded-For": c.Get("X-Forwarded-For"), // Raw header value
"IsProxyTrusted": c.IsProxyTrusted(), // Should be true
"RemoteIP": c.RequestCtx().RemoteIP().String(), // Proxy IP
})
})
```
## Enabling HTTP/2
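The simplest way to serve your Fiber app over HTTP/2 is to place a TLS-terminating reverse proxy in front of it. Popular choices include Nginx and Traefik.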
Nginx Example
See the [Complete Example with Nginx](#complete-example-with-nginx) above for a full configuration with HTTP/2 enabled.
Traefik Example
```yaml title="traefik.yaml"
entryPoints:
websecure:
address: ":443"
http:
routers:
app:
rule: "Host(`example.com`)"
entryPoints:
- websecure
service: app
tls: {}
services:
app:
loadBalancer:
servers:
- url: "http://127.0.0.1:3000"
```
With this configuration, Traefik terminates TLS and serves your app over HTTP/2.
## HTTP/3 (QUIC) Support
Early Hints (103 responses) are defined for HTTP and can be delivered over HTTP/1.1, HTTP/2, and HTTP/3. In practice, browsers process 103 responses most reliably over HTTP/2 and HTTP/3. Many reverse proxies also support HTTP/3 (QUIC):
- **Nginx**
- **Traefik**
Enabling HTTP/3 is optional but can provide lower latency and improved performance for clients that support it. If you enable HTTP/3, your Early Hints responses will still work as expected.
For more details, see the official documentation:
- [Nginx QUIC / HTTP/3](https://nginx.org/en/docs/quic.html)
- [Traefik HTTP/3](https://doc.traefik.io/traefik/reference/install-configuration/entrypoints/#http3)
================================================
FILE: docs/guide/routing.md
================================================
---
id: routing
title: 🔌 Routing
description: >-
Routing refers to how an application's endpoints (URIs) respond to client
requests.
sidebar_position: 1
toc_max_heading_level: 4
---
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
import RoutingHandler from './../partials/routing/handler.md';
## Handlers
## Automatic HEAD routes
Fiber automatically registers a `HEAD` route for every `GET` route you add. The generated handler chain mirrors the `GET` chain, so `HEAD` requests reuse middleware, status codes, and headers while the response body is suppressed.
```go title="GET handlers automatically expose HEAD"
app := fiber.New()
app.Get("/users/:id", func(c fiber.Ctx) error {
c.Set("X-User", c.Params("id"))
return c.SendStatus(fiber.StatusOK)
})
// HEAD /users/:id now returns the same headers and status without a body.
```
You can still register dedicated `HEAD` handlers—even with auto-registration enabled—and Fiber replaces the generated route so your implementation wins:
```go title="Override the generated HEAD handler"
app.Head("/users/:id", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent)
})
```
To opt out globally, start the app with `DisableHeadAutoRegister`:
```go title="Disable automatic HEAD registration"
handler := func(c fiber.Ctx) error {
c.Set("X-User", c.Params("id"))
return c.SendStatus(fiber.StatusOK)
}
app := fiber.New(fiber.Config{DisableHeadAutoRegister: true})
app.Get("/users/:id", handler) // HEAD /users/:id now returns 405 unless you add it manually.
```
Auto-generated `HEAD` routes participate in every router scope, including `Group` hierarchies, mounted sub-apps, parameterized and wildcard paths, and static file helpers. They also appear in route listings such as `app.Stack()` so tooling sees both the `GET` and `HEAD` entries.
## Paths
A route path paired with an HTTP method defines an endpoint. It can be a plain **string** or a **pattern**.
### Examples of route paths based on strings
```go
// This route path will match requests to the root route, "/":
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("root")
})
// This route path will match requests to "/about":
app.Get("/about", func(c fiber.Ctx) error {
return c.SendString("about")
})
// This route path will match requests to "/random.txt":
app.Get("/random.txt", func(c fiber.Ctx) error {
return c.SendString("random.txt")
})
```
As with the Express.js framework, the order in which routes are declared matters.
Routes are evaluated sequentially, so more specific paths should appear before those with variables.
:::info
Place routes with variable parameters after fixed paths to avoid unintended matches.
:::
## Parameters
Route parameters are dynamic segments in a path, either named or unnamed, used to capture values from the URL. Retrieve them with the [Params](../api/ctx.md#params) function using the parameter name or, for unnamed parameters, the wildcard (`*`) or plus (`+`) symbol with an index.
The characters `:`, `+`, and `*` introduce parameters.
Use `*` or `+` to capture segments greedily.
You can define optional parameters by appending `?` to a named segment. The `+` sign is greedy and required, while `*` acts as an optional greedy wildcard.
### Example of defining routes with route parameters
```go
// Parameters
app.Get("/user/:name/books/:title", func(c fiber.Ctx) error {
fmt.Fprintf(c, "%s\n", c.Params("name"))
fmt.Fprintf(c, "%s\n", c.Params("title"))
return nil
})
// Plus - greedy - not optional
app.Get("/user/+", func(c fiber.Ctx) error {
return c.SendString(c.Params("+"))
})
// Optional parameter
app.Get("/user/:name?", func(c fiber.Ctx) error {
return c.SendString(c.Params("name"))
})
// Wildcard - greedy - optional
app.Get("/user/*", func(c fiber.Ctx) error {
return c.SendString(c.Params("*"))
})
// This route path will match requests to "/v1/some/resource/name:customVerb", since the parameter character is escaped
app.Get(`/v1/some/resource/name\:customVerb`, func(c fiber.Ctx) error {
return c.SendString("Hello, Community")
})
```
:::info
The hyphen (`-`) and dot (`.`) are treated literally, so you can combine them with route parameters.
:::
:::info
Escape special parameter characters with `\\` to treat them literally. This technique is useful for custom methods like those in the [Google API Design Guide](https://cloud.google.com/apis/design/custom_methods). Wrap routes in backticks to keep escape sequences clear.
:::
```go
// http://localhost:3000/plantae/prunus.persica
app.Get("/plantae/:genus.:species", func(c fiber.Ctx) error {
fmt.Fprintf(c, "%s.%s\n", c.Params("genus"), c.Params("species"))
return nil // prunus.persica
})
```
```go
// http://localhost:3000/flights/LAX-SFO
app.Get("/flights/:from-:to", func(c fiber.Ctx) error {
fmt.Fprintf(c, "%s-%s\n", c.Params("from"), c.Params("to"))
return nil // LAX-SFO
})
```
Fiber's router detects when these characters belong to the literal path and handles them accordingly.
```go
// http://localhost:3000/shop/product/color:blue/size:xs
app.Get("/shop/product/color::color/size::size", func(c fiber.Ctx) error {
fmt.Fprintf(c, "%s:%s\n", c.Params("color"), c.Params("size"))
return nil // blue:xs
})
```
You can chain multiple named or unnamed parameters—including wildcard and plus segments—giving the router greater flexibility.
```go
// GET /@v1
// Params: "sign" -> "@", "param" -> "v1"
app.Get("/:sign:param", handler)
// GET /api-v1
// Params: "name" -> "v1"
app.Get("/api-:name", handler)
// GET /customer/v1/cart/proxy
// Params: "*1" -> "customer/", "*2" -> "/cart"
app.Get("/*v1*/proxy", handler)
// GET /v1/brand/4/shop/blue/xs
// Params: "*1" -> "brand/4", "*2" -> "blue/xs"
app.Get("/v1/*/shop/*", handler)
```
Fiber's routing is inspired by Express but intentionally omits regular expression routes due to their performance cost. You can try similar patterns using the Express route tester (v0.1.7).
### Constraints
Route constraints are evaluated after the incoming URL has matched a route and its path has been tokenized into parameter values. The feature was introduced in `v2.37.0` and inspired by [.NET Core](https://docs.microsoft.com/en-us/aspnet/core/fundamentals/routing?view=aspnetcore-6.0#route-constraints).
:::caution
Constraints aren't validation for parameters. If a parameter value doesn't satisfy its constraints, the route doesn't match: Fiber moves on to the next matching route or, if none matches, responds with **404 Not Found**.
:::
| Constraint | Example | Example matches |
| ----------------- | -------------------------------- | ------------------------------------------------------------------------------------------- |
| int | `:id<int>` | 123456789, -123456789 |
| bool | `:active<bool>` | true,false |
| guid | `:id<guid>` | CD2C1638-1638-72D5-1638-DEADBEEF1638 |
| float | `:weight<float>` | 1.234, -1,001.01e8 |
| minLen(value) | `:username<minLen(4)>` | Test (must be at least 4 characters) |
| maxLen(value) | `:filename<maxLen(8)>` | MyFile (must be no more than 8 characters) |
| len(length) | `:filename<len(12)>` | somefile.txt (exactly 12 characters) |
| min(value) | `:age<min(18)>` | 19 (Integer value must be at least 18) |
| max(value) | `:age<max(120)>` | 91 (Integer value must be no more than 120) |
| range(min,max) | `:age<range(18,120)>` | 91 (Integer value must be at least 18 but no more than 120) |
| alpha | `:name<alpha>` | Rick (String must consist of one or more alphabetical characters, a-z and case-insensitive) |
| datetime | `:dob<datetime(2006\\-01\\-02)>` | 2005-11-01 |
| regex(expression) | `:date<regex(\d{4}-\d{2}-\d{2})>` | 2022-08-27 (Must match regular expression) |
#### Examples
```go
app.Get("/:test", func(c fiber.Ctx) error {
return c.SendString(c.Params("test"))
})
// curl -X GET http://localhost:3000/12
// 12
// curl -X GET http://localhost:3000/1
// Not Found
```
You can use `;` for multiple constraints.
```go
app.Get("/:test", func(c fiber.Ctx) error {
return c.SendString(c.Params("test"))
})
// curl -X GET http://localhost:3000/120000
// Not Found
// curl -X GET http://localhost:3000/1
// Not Found
// curl -X GET http://localhost:3000/250
// 250
```
Fiber precompiles the regex when registering routes, so regex constraints don't pay a compilation cost on each request (the match itself still runs per request).
```go
app.Get(`/:date`, func(c fiber.Ctx) error {
return c.SendString(c.Params("date"))
})
// curl -X GET http://localhost:3000/125
// Not Found
// curl -X GET http://localhost:3000/test
// Not Found
// curl -X GET http://localhost:3000/2022-08-27
// 2022-08-27
```
:::caution
When using the datetime constraint, escape routing characters (`*`, `+`, `?`, `:`, `/`, `<`, `>`, `;`, `(`, `)`) with `\\` to avoid misparsing.
:::
#### Optional Parameter Example
You can impose constraints on optional parameters as well.
```go
app.Get("/:test?", func(c fiber.Ctx) error {
return c.SendString(c.Params("test"))
})
// curl -X GET http://localhost:3000/42
// 42
// curl -X GET http://localhost:3000/
//
// curl -X GET http://localhost:3000/7.0
// Not Found
```
#### Custom Constraint
Custom constraints can be added to Fiber using the `app.RegisterCustomConstraint` method. Your constraints must implement the `CustomConstraint` interface.
:::caution
Custom constraints can override built-in constraints. If a custom constraint has the same name as a built-in constraint, the custom constraint will be used instead. This allows for more flexibility in defining route parameter constraints.
:::
Add external constraints when you need stricter rules, such as verifying that a parameter is a valid ULID.
```go
// CustomConstraint is an interface for custom constraints
type CustomConstraint interface {
// Name returns the name of the constraint.
// This name is used in the constraint matching.
Name() string
// Execute executes the constraint.
// It returns true if the constraint is matched and right.
// param is the parameter value to check.
// args are the constraint arguments.
Execute(param string, args ...string) bool
}
```
You can check the example below:
```go
type UlidConstraint struct {
fiber.CustomConstraint
}
func (*UlidConstraint) Name() string {
return "ulid"
}
func (*UlidConstraint) Execute(param string, args ...string) bool {
_, err := ulid.Parse(param)
return err == nil
}
func main() {
app := fiber.New()
app.RegisterCustomConstraint(&UlidConstraint{})
app.Get("/login/:id", func(c fiber.Ctx) error {
return c.SendString("...")
})
app.Listen(":3000")
// /login/01HK7H9ZE5BFMK348CPYP14S0Z -> 200
// /login/12345 -> 404
}
```
## Middleware
Functions that are designed to make changes to the request or response are called **middleware functions**. [Next](../api/ctx.md#next) is a **Fiber** router function that, when called, executes the **next** function that **matches** the current route.
### Example of a middleware function
```go
app.Use(func(c fiber.Ctx) error {
// Set a custom header on all responses:
c.Set("X-Custom-Header", "Hello, World")
// Go to next middleware:
return c.Next()
})
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
```
The path passed to `Use` is a **mount** (or **prefix**) path: the middleware only runs for requests whose paths begin with it.
:::note
Prefix matches must now end at a slash boundary (or be an exact match). For example, `/api` runs for `/api` and `/api/users` but no longer for `/apiv2`. Parameter tokens such as `:name`, `:name?`, `*`, and `+` are still expanded before this boundary check runs.
:::
### Constraints on Adding Routes Dynamically
:::caution
Adding routes dynamically after the application has started is not supported due to design and performance considerations. Make sure to define all your routes before the application starts.
:::
## Grouping
If you have many endpoints, you can organize your routes using `Group`.
```go
func main() {
app := fiber.New()
api := app.Group("/api", middleware) // /api
v1 := api.Group("/v1", middleware) // /api/v1
v1.Get("/list", handler) // /api/v1/list
v1.Get("/user", handler) // /api/v1/user
v2 := api.Group("/v2", middleware) // /api/v2
v2.Get("/list", handler) // /api/v2/list
v2.Get("/user", handler) // /api/v2/user
log.Fatal(app.Listen(":3000"))
}
```
More information is available in our [Grouping Guide](./grouping.md).
================================================
FILE: docs/guide/templates.md
================================================
---
id: templates
title: 📝 Templates
description: Fiber supports server-side template engines.
sidebar_position: 3
---
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
Templates render dynamic content without requiring a separate frontend framework.
## Template Engines
Fiber accepts a custom template engine during app initialization.
```go
app := fiber.New(fiber.Config{
// Provide a template engine
Views: engine,
// Default path for views, overridden when calling Render()
ViewsLayout: "layouts/main",
// Enables/Disables access to `ctx.Locals()` entries in rendered views
// (defaults to false)
PassLocalsToViews: false,
})
```
### Supported Engines
Fiber maintains a [templates](https://docs.gofiber.io/template) package that wraps several engines:
* [ace](https://docs.gofiber.io/template/ace/)
* [amber](https://docs.gofiber.io/template/amber/)
* [django](https://docs.gofiber.io/template/django/)
* [handlebars](https://docs.gofiber.io/template/handlebars)
* [html](https://docs.gofiber.io/template/html)
* [jet](https://docs.gofiber.io/template/jet)
* [mustache](https://docs.gofiber.io/template/mustache)
* [pug](https://docs.gofiber.io/template/pug)
* [slim](https://docs.gofiber.io/template/slim)
:::info
Custom engines implement the `Views` interface to work with Fiber.
:::
```go title="Views interface"
type Views interface {
// Fiber executes Load() on app initialization to load/parse the templates
Load() error
// Outputs a template to the provided buffer using the provided template,
// template name, and bound data
Render(io.Writer, string, interface{}, ...string) error
}
```
:::note
The `Render` method powers [**ctx.Render\(\)**](../api/ctx.md#render), which accepts a template name and data to bind.
:::
## Rendering Templates
After configuring an engine, handlers call [**ctx.Render\(\)**](../api/ctx.md#render) with a template name and data to send the rendered output.
```go title="Signature"
func (c Ctx) Render(name string, bind Map, layouts ...string) error
```
:::info
By default, [**ctx.Render\(\)**](../api/ctx.md#render) wraps the rendered template in the layout configured by `ViewsLayout`. Pass one or more layout names in the `layouts` argument to override this default.
:::
```go
app.Get("/", func(c fiber.Ctx) error {
return c.Render("index", fiber.Map{
"Title": "Hello, World!",
})
})
```
```html
{{.Title}}
```
:::caution
When `PassLocalsToViews` is enabled, all values set using `ctx.Locals(key, value)` are passed to the template. Use unique keys to avoid collisions.
:::
## Advanced Templating
### Custom Functions
Fiber supports adding custom functions to templates.
#### AddFunc
Adds a global function to all templates.
```go title="Signature"
func (e *Engine) AddFunc(name string, fn interface{}) IEngineCore
```
```go
// Add `ToUpper` to engine
engine := html.New("./views", ".html")
engine.AddFunc("ToUpper", func(s string) string {
return strings.ToUpper(s)
}
// Initialize Fiber App
app := fiber.New(fiber.Config{
Views: engine,
})
app.Get("/", func (c fiber.Ctx) error {
return c.Render("index", fiber.Map{
"Content": "hello, World!"
})
})
```
```html
This will be in {{ToUpper "all caps"}}:
{{ToUpper .Content}}
```
#### AddFuncMap
Adds a Map of functions (keyed by name) to all templates.
```go title="Signature"
func (e *Engine) AddFuncMap(m map[string]interface{}) IEngineCore
```
```go
// Add `ToUpper` to engine
engine := html.New("./views", ".html")
engine.AddFuncMap(map[string]interface{}{
"ToUpper": func(s string) string {
return strings.ToUpper(s)
},
})
// Initialize Fiber App
app := fiber.New(fiber.Config{
Views: engine,
})
app.Get("/", func (c fiber.Ctx) error {
return c.Render("index", fiber.Map{
"Content": "hello, world!"
})
})
```
```html
This will be in {{ToUpper "all caps"}}:
{{ToUpper .Content}}
```
* For more advanced template documentation, please visit the [gofiber/template GitHub Repository](https://github.com/gofiber/template).
## Full Example
```go
package main
import (
"log"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/template/html/v2"
)
func main() {
// Initialize standard Go html template engine
engine := html.New("./views", ".html")
// If you want to use another engine,
// just replace with following:
// Create a new engine with django
// engine := django.New("./views", ".django")
app := fiber.New(fiber.Config{
Views: engine,
})
app.Get("/", func(c fiber.Ctx) error {
// Render index template
return c.Render("index", fiber.Map{
"Title": "Go Fiber Template Example",
"Description": "An example template",
"Greeting": "Hello, World!",
});
})
log.Fatal(app.Listen(":3000"))
}
```
```html
{{.Title}}
{{.Title}}
{{.Greeting}}
```
================================================
FILE: docs/guide/utils.md
================================================
---
id: utils
title: 🧰 Utils
sidebar_position: 8
toc_max_heading_level: 4
---
## Generics
### Convert
Converts a string to a specific type while handling errors and optional defaults.
It wraps conversion and fallback logic to keep your code clean and consistent.
```go title="Signature"
func Convert[T any](value string, converter func(string) (T, error), defaultValue ...T) (T, error)
```
```go title="Example"
// GET http://example.com/id/bb70ab33-d455-4a03-8d78-d3c1dacae9ff
app.Get("/id/:id", func(c fiber.Ctx) error {
fiber.Convert(c.Params("id"), uuid.Parse) // UUID(bb70ab33-d455-4a03-8d78-d3c1dacae9ff), nil
})
// GET http://example.com/search?id=65f6f54221fb90e6a6b76db7
app.Get("/search", func(c fiber.Ctx) error {
fiber.Convert(c.Query("id"), mongo.ParseObjectID) // objectid(65f6f54221fb90e6a6b76db7), nil
fiber.Convert(c.Query("id"), uuid.Parse) // uuid.Nil, error(cannot parse given uuid)
fiber.Convert(c.Query("id"), uuid.Parse, mongo.NewObjectID) // new object id generated and return nil as error.
return nil
})
// ...
```
### GetReqHeader
Retrieves an HTTP request header as a specific type using generics.
```go title="Signature"
func GetReqHeader[V GenericType](c Ctx, key string, defaultValue ...V) V
```
```go title="Example"
app.Get("/search", func(c fiber.Ctx) error {
// curl -X GET http://example.com/search -H "X-Request-ID: 12345" -H "X-Request-Name: John"
fiber.GetReqHeader[int](c, "X-Request-ID") // => returns 12345 as integer.
fiber.GetReqHeader[string](c, "X-Request-Name") // => returns "John" as string.
fiber.GetReqHeader[string](c, "unknownParam", "default") // => returns "default" as string.
// ...
})
```
### Locals
Reads or writes local values in the request context using generics.
```go title="Signature"
// Set a value
func Locals[V any](c Ctx, key any, value ...V) V
// Get a value
func Locals[V any](c Ctx, key any) V
```
```go title="Example"
app.Use("/user/:user/:id", func(c fiber.Ctx) error {
// set local values
fiber.Locals[string](c, "user", "john")
fiber.Locals[int](c, "id", 25)
// ...
return c.Next()
})
app.Get("/user/*", func(c fiber.Ctx) error {
// get local values
name := fiber.Locals[string](c, "user") // john
age := fiber.Locals[int](c, "id") // 25
// ...
})
```
### Params
Retrieves route parameters as a specific type.
```go title="Signature"
func Params[V GenericType](c Ctx, key string, defaultValue ...V) V
```
```go title="Example"
app.Get("/user/:user/:id", func(c fiber.Ctx) error {
// http://example.com/user/john/25
fiber.Params[int](c, "id") // => returns 25 as integer.
fiber.Params[int](c, "unknownParam", 99) // => returns the default 99 as integer.
// ...
return c.SendString("Hello, " + fiber.Params[string](c, "user"))
})
```
### Query
Retrieves query parameters as a specific type.
```go title="Signature"
func Query[V GenericType](c Ctx, key string, defaultValue ...V) V
```
```go title="Example"
app.Get("/search", func(c fiber.Ctx) error {
// http://example.com/search?name=john&age=25
fiber.Query[string](c, "name") // => returns "john"
fiber.Query[int](c, "age") // => returns 25 as integer.
fiber.Query[string](c, "unknownParam", "default") // => returns "default" as string.
// ...
})
```
### RoutePatternMatch
Checks whether a given path matches a Fiber route pattern. Useful for testing
patterns without registering them. Patterns may contain parameters, wildcards
and optional segments. An optional `Config` allows control over case sensitivity
and strict routing.
```go title="Signature"
func RoutePatternMatch(path, pattern string, cfg ...Config) bool
```
```go title="Example"
fiber.RoutePatternMatch("/user/john", "/user/:name") // true
fiber.RoutePatternMatch(
"/User/john",
"/user/:name",
fiber.Config{CaseSensitive: true},
) // false
```
================================================
FILE: docs/guide/validation.md
================================================
---
id: validation
title: 🔎 Validation
sidebar_position: 5
---
## Validator package
Fiber's [Bind](../api/bind.md#validation) function binds request data to a struct and, when a `StructValidator` is configured, validates it.
```go title="Basic Example"
import "github.com/go-playground/validator/v10"
type structValidator struct {
validate *validator.Validate
}
// Validator needs to implement the Validate method
func (v *structValidator) Validate(out any) error {
return v.validate.Struct(out)
}
// Set up your validator in the config
app := fiber.New(fiber.Config{
StructValidator: &structValidator{validate: validator.New()},
})
// Note:
// StructValidator runs only for struct destinations (or pointers to structs).
// Binding into maps and other non-struct types skips validation.
type User struct {
Name string `json:"name" form:"name" query:"name" validate:"required"`
Age int `json:"age" form:"age" query:"age" validate:"gte=0,lte=100"`
}
app.Post("/", func(c fiber.Ctx) error {
user := new(User)
// Works with all bind methods—Body, Query, Form, etc.
if err := c.Bind().Body(user); err != nil { // validation errors are returned here
return err
}
return c.JSON(user)
})
```
```go title="Advanced Validation Example"
type User struct {
Name string `json:"name" validate:"required,min=3,max=32"`
Email string `json:"email" validate:"required,email"`
Age int `json:"age" validate:"gte=0,lte=100"`
Password string `json:"password" validate:"required,min=8"`
Website string `json:"website" validate:"url"`
}
// Custom validation error messages
type UserWithCustomMessages struct {
Name string `json:"name" validate:"required,min=3,max=32" message:"Name is required and must be between 3 and 32 characters"`
Email string `json:"email" validate:"required,email" message:"Valid email is required"`
Age int `json:"age" validate:"gte=0,lte=100" message:"Age must be between 0 and 100"`
}
app.Post("/user", func(c fiber.Ctx) error {
user := new(User)
if err := c.Bind().Body(user); err != nil {
// Handle validation errors
if validationErrors, ok := err.(validator.ValidationErrors); ok {
for _, e := range validationErrors {
// e.Field() - field name
// e.Tag() - validation tag
// e.Value() - invalid value
// e.Param() - validation parameter
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"field": e.Field(),
"error": e.Error(),
})
}
}
return err
}
return c.JSON(user)
})
```
```go title="Custom Validator Example"
// Custom validator for password strength
type PasswordValidator struct {
validate *validator.Validate
}
func (v *PasswordValidator) Validate(out any) error {
if err := v.validate.Struct(out); err != nil {
return err
}
// Custom password validation logic
if user, ok := out.(*User); ok {
if len(user.Password) < 8 {
return errors.New("password must be at least 8 characters")
}
// Add more password validation rules here
}
return nil
}
// Usage
app := fiber.New(fiber.Config{
StructValidator: &PasswordValidator{validate: validator.New()},
})
```
================================================
FILE: docs/intro.md
================================================
---
slug: /
id: welcome
title: 👋 Welcome
sidebar_position: 1
---
Welcome to Fiber's online API documentation, complete with examples to help you start building web applications right away!
**Fiber** is an [Express](https://github.com/expressjs/express)-inspired **web framework** built on top of [Fasthttp](https://github.com/valyala/fasthttp), the **fastest** HTTP engine for [Go](https://go.dev/doc/). It is designed to facilitate rapid development with **zero memory allocations** and a strong focus on **performance**.
These docs cover **Fiber v3**.
Looking to practice Fiber concepts hands-on? Check out our [Learning Resources](./extra/learning-resources) for interactive challenges and tutorials.
### Installation
First, [download](https://go.dev/dl/) and install Go. Version `1.25` or higher is required.
Install Fiber using the [`go get`](https://pkg.go.dev/cmd/go/#hdr-Add_dependencies_to_current_module_and_install_them) command:
```bash
go get github.com/gofiber/fiber/v3
```
### Zero Allocation
Fiber is optimized for **high performance**, meaning values returned from **fiber.Ctx** are **not** immutable by default and **will** be reused across requests. As a rule of thumb, you should use context values only within the handler and **must not** keep any references. Once you return from the handler, any values obtained from the context will be reused in future requests. Here is an example:
```go
func handler(c fiber.Ctx) error {
// Variable is only valid within this handler
result := c.Params("foo")
// ...
}
```
If you need to persist such values outside the handler, make copies of their **underlying buffer** using the [copy](https://pkg.go.dev/builtin/#copy) builtin. Here is an example of persisting a string:
```go
func handler(c fiber.Ctx) error {
// Variable is only valid within this handler
result := c.Params("foo")
// Make a copy
buffer := make([]byte, len(result))
copy(buffer, result)
resultCopy := string(buffer)
// Variable is now valid indefinitely
// ...
}
```
Fiber provides `GetString` and `GetBytes` methods on the app that detach values when `Immutable` is enabled and the data isn't already read-only. If it's disabled, use `utils.CopyString` and `utils.CopyBytes` to allocate only when necessary.
```go
app.Get("/:foo", func(c fiber.Ctx) error {
// Detach if necessary when Immutable is enabled
result := c.App().GetString(c.Params("foo"))
// ...
})
```
Alternatively, you can enable the `Immutable` setting. This makes all values returned from the context immutable, allowing you to persist them anywhere. Note that this comes at the cost of performance.
```go
app := fiber.New(fiber.Config{
Immutable: true,
})
```
For more information, please refer to [#426](https://github.com/gofiber/fiber/issues/426), [#185](https://github.com/gofiber/fiber/issues/185), and [#3012](https://github.com/gofiber/fiber/issues/3012).
### Hello, World
Here is the simplest **Fiber** application you can create:
```go
package main
import "github.com/gofiber/fiber/v3"
func main() {
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Listen(":3000")
}
```
```bash
go run server.go
```
Browse to `http://localhost:3000` and you should see `Hello, World!` displayed on the page.
### Basic Routing
Routing determines how an application responds to a client request at a particular endpoint—a combination of path and HTTP request method (`GET`, `PUT`, `POST`, etc.).
Each route can have **multiple handler functions** that are executed when the route is matched.
Route definitions follow the structure below:
```go
// Function signature
app.Method(path string, ...func(fiber.Ctx) error)
```
- `app` is an instance of **Fiber**
- `Method` is an [HTTP request method](./api/app#route-handlers): `GET`, `PUT`, `POST`, etc.
- `path` is a virtual path on the server
- `func(fiber.Ctx) error` is a callback function containing the [Context](./api/ctx) executed when the route is matched
#### Simple Route
```go
// Respond with "Hello, World!" on root path "/"
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
```
#### Parameters
```go
// GET http://localhost:3000/hello%20world
app.Get("/:value", func(c fiber.Ctx) error {
return c.SendString("value: " + c.Params("value"))
// => Response: "value: hello world"
})
```
#### Optional Parameter
```go
// GET http://localhost:3000/john
app.Get("/:name?", func(c fiber.Ctx) error {
if c.Params("name") != "" {
return c.SendString("Hello " + c.Params("name"))
// => Response: "Hello john"
}
return c.SendString("Where is john?")
// => Response: "Where is john?"
})
```
#### Wildcards
```go
// GET http://localhost:3000/api/user/john
app.Get("/api/*", func(c fiber.Ctx) error {
return c.SendString("API path: " + c.Params("*"))
// => Response: "API path: user/john"
})
```
### Static Files
To serve static files such as **images**, **CSS**, and **JavaScript** files, use the [static middleware](./middleware/static.md).
Use the following code to serve files in a directory named `./public`:
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/static"
)
func main() {
app := fiber.New()
app.Use("/", static.New("./public"))
app.Listen(":3000")
}
```
Now, you can access the files in the `./public` directory via your browser:
```text
http://localhost:3000/hello.html
http://localhost:3000/js/jquery.js
http://localhost:3000/css/style.css
```
================================================
FILE: docs/middleware/_category_.json
================================================
{
"label": "\uD83E\uDDEC Middleware",
"position": 4,
"collapsed": true,
"link": {
"type": "generated-index",
"description": "Middleware is a function chained in the HTTP request cycle with access to the Context which it uses to perform a specific action, for example, logging every request or enabling CORS."
}
}
================================================
FILE: docs/middleware/adaptor.md
================================================
---
id: adaptor
---
# Adaptor
The `adaptor` package converts between Fiber and `net/http`, letting you reuse handlers, middleware, and requests across both frameworks.
:::tip
Fiber can register plain `net/http` handlers directly—just pass an `http.Handler`,
`http.HandlerFunc`, or `func(http.ResponseWriter, *http.Request)` to any router
method and it will be adapted automatically. The adaptor helpers remain valuable
when you need to convert middleware, swap handler directions, or transform
requests explicitly.
:::
:::caution Fiber features are unavailable
Even when you register them directly, adapted `net/http` handlers still run with standard
library semantics. They don't have access to `fiber.Ctx`, and the compatibility layer comes
with additional overhead compared to native Fiber handlers. Use them for interop and legacy
scenarios, but prefer Fiber handlers when performance or Fiber-specific APIs matter.
:::
## Features
- Convert `net/http` handlers and middleware to Fiber handlers
- Convert Fiber handlers to `net/http` handlers
- Convert a Fiber context (`fiber.Ctx`) into an `http.Request`
- Copy values stored in a `context.Context` onto a `fasthttp.RequestCtx`
:::note Body size limits when running Fiber from net/http
When Fiber is executed from a `net/http` server through `FiberHandler`, `FiberHandlerFunc`,
or `FiberApp`, the adaptor enforces the app's configured `BodyLimit`. The app's `BodyLimit` defaults to **4 MiB** if a non-positive value is provided during configuration. Requests exceeding the active limit receive `413 Request Entity Too Large`.
:::
## API Reference
| Name | Signature | Description |
|-------------------------------|-------------------------------------------------------------------------------|-------------------------------------------------------------------------------|
| `HTTPHandler` | `HTTPHandler(h http.Handler) fiber.Handler` | Converts `http.Handler` to `fiber.Handler` |
| `HTTPHandlerWithContext` | `HTTPHandlerWithContext(h http.Handler) fiber.Handler` | Converts `http.Handler` to `fiber.Handler`, propagating Fiber's local context |
| `HTTPHandlerFunc` | `HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler` | Converts `http.HandlerFunc` to `fiber.Handler` |
| `HTTPMiddleware` | `HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler` | Converts `http.Handler` middleware to `fiber.Handler` middleware |
| `FiberHandler` | `FiberHandler(h fiber.Handler) http.Handler` | Converts `fiber.Handler` to `http.Handler` |
| `FiberHandlerFunc` | `FiberHandlerFunc(h fiber.Handler) http.HandlerFunc` | Converts `fiber.Handler` to `http.HandlerFunc` |
| `FiberApp` | `FiberApp(app *fiber.App) http.HandlerFunc` | Converts an entire Fiber app to an `http.HandlerFunc` |
| `ConvertRequest` | `ConvertRequest(c fiber.Ctx, forServer bool) (*http.Request, error)` | Converts `fiber.Ctx` into an `*http.Request` |
| `LocalContextFromHTTPRequest` | `LocalContextFromHTTPRequest(r *http.Request) (context.Context, bool)` | Extracts the propagated `context.Context` from an adapted `http.Request` |
| `CopyContextToFiberContext` | `CopyContextToFiberContext(context any, requestContext *fasthttp.RequestCtx)` | Copies values stored in a `context.Context` onto a `fasthttp.RequestCtx` (deprecated) |
---
## Usage Examples
### 1. Using `net/http` handlers in Fiber (`HTTPHandler`, `HTTPHandlerFunc`)
Run standard `net/http` handlers inside Fiber. Fiber can auto-adapt them, or you can
explicitly convert them when you want to cache or share the converted handler.
```go
package main
import (
"fmt"
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
app := fiber.New()
// Fiber adapts net/http handlers for you during registration.
app.Get("/", http.HandlerFunc(helloHandler))
// You can also convert and reuse the handler manually.
cached := adaptor.HTTPHandler(http.HandlerFunc(helloHandler))
app.Get("/cached", cached)
// When you already have an http.HandlerFunc, convert it directly.
app.Get("/func", adaptor.HTTPHandlerFunc(helloHandler))
app.Listen(":3000")
}
func helloHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hello from net/http!")
}
```
### 2. Using `net/http` middleware with Fiber (`HTTPMiddleware`)
Middleware written for `net/http` can run inside Fiber:
```go
package main
import (
"log"
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
app := fiber.New()
// Apply an http middleware in Fiber
app.Use(adaptor.HTTPMiddleware(loggingMiddleware))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello Fiber!")
})
app.Listen(":3000")
}
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("Request received")
next.ServeHTTP(w, r)
})
}
```
### 3. Using Fiber handlers in `net/http` (`FiberHandler`)
You can use Fiber handlers from `net/http`:
```go
package main
import (
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
// Convert a Fiber handler to an http.Handler
http.Handle("/", adaptor.FiberHandler(helloFiber))
// Convert a Fiber handler to an http.HandlerFunc
http.HandleFunc("/func", adaptor.FiberHandlerFunc(helloFiber))
http.ListenAndServe(":3000", nil)
}
func helloFiber(c fiber.Ctx) error {
return c.SendString("Hello from Fiber!")
}
```
### 4. Converting Fiber handlers to `http.HandlerFunc` (`FiberHandlerFunc`)
When you specifically need an `http.HandlerFunc`, wrap the Fiber handler directly:
```go
package main
import (
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
http.HandleFunc("/func-only", adaptor.FiberHandlerFunc(helloFiber))
http.ListenAndServe(":3000", nil)
}
func helloFiber(c fiber.Ctx) error {
return c.SendString("Hello from Fiber!")
}
```
### 5. Running a full Fiber app inside `net/http` (`FiberApp`)
You can wrap a full Fiber app inside `net/http`:
```go
package main
import (
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello from Fiber!")
})
// Run Fiber inside an http server
http.ListenAndServe(":3000", adaptor.FiberApp(app))
}
```
### 6. Converting `fiber.Ctx` to `*http.Request` (`ConvertRequest`)
Create an `*http.Request` from a `fiber.Ctx`. The `forServer` parameter determines how
server-oriented fields are populated:
- Use `forServer = true` when the converted request will be passed into a `net/http` handler
(sets `RequestURI`, `RemoteAddr`, and `TLS` fields for server-side handling)
- Use `forServer = false` when creating a request for client-side use (e.g., making an
outbound HTTP request with `http.Client`)
```go
package main
import (
"net/http"
"net/http/httptest"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
func main() {
app := fiber.New()
app.Get("/request", handleRequest)
app.Listen(":3000")
}
func handleRequest(c fiber.Ctx) error {
// Use forServer = true when passing to a net/http handler
httpReq, err := adaptor.ConvertRequest(c, true)
if err != nil {
return err
}
// Pass the request to a net/http handler.
recorder := httptest.NewRecorder()
http.DefaultServeMux.ServeHTTP(recorder, httpReq)
return c.SendString("Converted Request URL: " + httpReq.URL.String())
}
```
### 7. Passing Fiber user context into `net/http`
This example shows a realistic flow: a Fiber middleware sets a request-scoped `context.Context` (with a `request_id`) on the Fiber context, then an adapted `net/http` handler retrieves it via `LocalContextFromHTTPRequest`.
```go
package main
import (
"context"
"fmt"
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
type ctxKey string
const requestIDKey ctxKey = "request_id"
func main() {
app := fiber.New()
// Create a request-scoped context in Fiber (e.g., request id, auth claims, trace span).
app.Use(func(c fiber.Ctx) error {
reqID := c.Get("X-Request-ID")
ctx := context.WithValue(context.Background(), requestIDKey, reqID)
// Fiber stores request-scoped context as "user context".
c.SetContext(ctx)
return c.Next()
})
// Run a standard net/http handler with Fiber's user context propagated.
app.Get("/hello", adaptor.HTTPHandlerWithContext(http.HandlerFunc(handleRequest)))
app.Listen(":3000")
}
func handleRequest(w http.ResponseWriter, r *http.Request) {
ctx, ok := adaptor.LocalContextFromHTTPRequest(r)
if !ok || ctx == nil {
http.Error(w, "missing propagated context", http.StatusInternalServerError)
return
}
reqID, _ := ctx.Value(requestIDKey).(string)
fmt.Fprintf(w, "Hello from net/http (request_id=%s)\n", reqID)
}
```
### 8. Copying context values onto `fasthttp.RequestCtx` (`CopyContextToFiberContext`)
`CopyContextToFiberContext` copies values stored in a `context.Context` onto a
`fasthttp.RequestCtx`. The function is marked deprecated in code because it uses
reflection and unsafe operations—prefer explicit parameter passing when possible.
When you do need it, call it immediately after you add values to the `net/http`
context so Fiber can read them via `c.Context()`:
```go
package main
import (
"context"
"net/http"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/adaptor"
)
type contextKey string
func main() {
app := fiber.New()
app.Use(func(c fiber.Ctx) error {
// Convert the Fiber context to an http.Request so we can attach context values.
httpReq, err := adaptor.ConvertRequest(c, true)
if err != nil {
return err
}
// Add context data and push it back to the Fiber context.
enriched := httpReq.WithContext(context.WithValue(httpReq.Context(), contextKey("requestID"), "req-123"))
adaptor.CopyContextToFiberContext(enriched.Context(), c.RequestCtx())
return c.Next()
})
app.Get("/", func(c fiber.Ctx) error {
if id, ok := c.Context().Value(contextKey("requestID")).(string); ok {
return c.SendString("Request ID: " + id)
}
return c.SendStatus(fiber.StatusNotFound)
})
app.Listen(":3000")
}
```
---
## Summary
The `adaptor` package lets Fiber and `net/http` interoperate so you can:
- Convert handlers and middleware in both directions
- Run Fiber apps inside `net/http`
- Convert `fiber.Ctx` to `http.Request`
- Propagate Fiber's user context into adapted `net/http` handlers
This makes it straightforward to integrate Fiber with existing Go projects or migrate between frameworks.
================================================
FILE: docs/middleware/basicauth.md
================================================
---
id: basicauth
---
# BasicAuth
Basic Authentication middleware for [Fiber](https://github.com/gofiber/fiber) that provides HTTP basic auth. It calls the next handler for valid credentials and returns [`401 Unauthorized`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/401) for missing or invalid credentials, [`400 Bad Request`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/400) for malformed `Authorization` headers, or [`431 Request Header Fields Too Large`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/431) when the header exceeds size limits. Credentials may omit Base64 padding as permitted by RFC 7235's `token68` syntax.
The default unauthorized response includes the header `WWW-Authenticate: Basic realm="Restricted", charset="UTF-8"`, sets `Cache-Control: no-store`, and adds a `Vary: Authorization` header. Only the `UTF-8` charset is supported; setting `Charset` to any other value causes a panic.
## Signatures
```go
func New(config Config) fiber.Handler
func UsernameFromContext(ctx any) string
```
`UsernameFromContext` accepts a `fiber.CustomCtx`, a `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`.
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/basicauth"
)
```
Once your Fiber app is initialized, choose one of the following approaches:
```go
// Provide a minimal config
app.Use(basicauth.New(basicauth.Config{
Users: map[string]string{
// "doe" hashed using SHA-256
"john": "{SHA256}eZ75KhGvkY4/t0HfQpNPO1aO0tk6wd908bjUGieTKm8=",
// "123456" hashed using bcrypt
"admin": "$2a$10$gTYwCN66/tBRoCr3.TXa1.v1iyvwIF7GRBqxzv7G.AHLMt/owXrp.",
},
}))
// Or extend your config for customization
app.Use(basicauth.New(basicauth.Config{
Users: map[string]string{
// "doe" hashed using SHA-256
"john": "{SHA256}eZ75KhGvkY4/t0HfQpNPO1aO0tk6wd908bjUGieTKm8=",
// "123456" hashed using bcrypt
"admin": "$2a$10$gTYwCN66/tBRoCr3.TXa1.v1iyvwIF7GRBqxzv7G.AHLMt/owXrp.",
},
Realm: "Forbidden",
Authorizer: func(user, pass string, c fiber.Ctx) bool {
// custom validation logic (placeholder: this example only checks the username; real code must also verify pass)
return (user == "john" || user == "admin")
},
Unauthorized: func(c fiber.Ctx) error {
return c.SendFile("./unauthorized.html")
},
}))
```
### Password hashes
Passwords must be supplied in pre-hashed form. The middleware detects the
hashing algorithm from a prefix:
- `"{SHA512}"` or `"{SHA256}"` followed by a base64-encoded digest
- standard bcrypt strings beginning with `$2`
If no prefix is present, the value is interpreted as a SHA-256 digest encoded in
hex or base64. Plaintext passwords are rejected.
#### Generating SHA-256 and SHA-512 passwords
Create a digest, encode it in base64, and prefix it with `{SHA256}` or
`{SHA512}` before adding it to `Users`:
```bash
# SHA-256
printf 'secret' | openssl dgst -binary -sha256 | base64
# SHA-512
printf 'secret' | openssl dgst -binary -sha512 | base64
```
Include the prefix in your config:
```go
Users: map[string]string{
"john": "{SHA256}K7gNU3sdo+OL0wNhqoVWhr3g6s1xYv72ol/pe/Unols=",
"admin": "{SHA512}vSsar3708Jvp9Szi2NWZZ02Bqp1qRCFpbcTZPdBhnWgs5WtNZKnvCXdhztmeD2cmW192CF5bDufKRpayrW/isg==",
}
```
## Config
| Property | Type | Description | Default |
|:----------------|:----------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
| Users | `map[string]string` | Users maps usernames to **hashed** passwords (e.g. bcrypt, `{SHA256}`). | `map[string]string{}` |
| Realm | `string` | Realm is a string to define the realm attribute of BasicAuth. The realm identifies the system to authenticate against and can be used by clients to save credentials. | `"Restricted"` |
| Charset | `string` | Charset sent in the `WWW-Authenticate` header. Only `"UTF-8"` is supported (case-insensitive). | `"UTF-8"` |
| HeaderLimit | `int` | Maximum allowed length of the `Authorization` header. Requests exceeding this limit are rejected. | `8192` |
| Authorizer | `func(string, string, fiber.Ctx) bool` | Authorizer defines a function to check the credentials. It will be called with a username, password, and the current context and is expected to return true or false to indicate approval. | `nil` |
| Unauthorized | `fiber.Handler` | Unauthorized defines the response body for unauthorized responses. | `nil` |
| BadRequest | `fiber.Handler` | BadRequest defines the response for malformed `Authorization` headers. | `nil` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
Users: map[string]string{},
Realm: "Restricted",
Charset: "UTF-8",
HeaderLimit: 8192,
Authorizer: nil,
Unauthorized: nil,
BadRequest: nil,
}
```
================================================
FILE: docs/middleware/cache.md
================================================
---
id: cache
---
# Cache
Cache middleware for [Fiber](https://github.com/gofiber/fiber) that intercepts responses and stores the body, `Content-Type`, and status code under a key derived from the request path and method. Special thanks to [@codemicro](https://github.com/codemicro/fiber-cache) for contributing this middleware to Fiber core.
By default, cached responses expire after five minutes and the middleware stores up to 1 MB of response bodies.
Request directives
- `Cache-Control: no-cache` returns the latest response while still caching it, so the status is always `miss`.
- `Cache-Control: no-store` skips caching and always forwards a fresh response.
If the response includes a `Cache-Control: max-age` directive, its value sets the cache entry's expiration.
Cacheable status codes
The middleware caches these RFC 7231 status codes:
- `200: OK`
- `203: Non-Authoritative Information`
- `204: No Content`
- `206: Partial Content`
- `300: Multiple Choices`
- `301: Moved Permanently`
- `404: Not Found`
- `405: Method Not Allowed`
- `410: Gone`
- `414: URI Too Long`
- `501: Not Implemented`
Responses with other status codes result in an `unreachable` cache status.
For more about cacheable status codes and RFC 7231, see:
- [Cacheable - MDN Web Docs](https://developer.mozilla.org/en-US/docs/Glossary/Cacheable)
- [RFC7231 - Hypertext Transfer Protocol (HTTP/1.1): Semantics and Content](https://datatracker.ietf.org/doc/html/rfc7231)
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/cache"
"github.com/gofiber/utils/v2"
)
```
Once your Fiber app is initialized, register the middleware:
```go
// Initialize default config
app.Use(cache.New())
// Or extend the config for customization
app.Use(cache.New(cache.Config{
Next: func(c fiber.Ctx) bool {
return fiber.Query[bool](c, "noCache")
},
Expiration: 30 * time.Minute,
DisableCacheControl: true,
}))
```
Customize the cache key and expiration; the HTTP method is appended automatically:
```go
app.Use(cache.New(cache.Config{
ExpirationGenerator: func(c fiber.Ctx, cfg *cache.Config) time.Duration {
newCacheTime, _ := strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
return time.Second * time.Duration(newCacheTime)
},
KeyGenerator: func(c fiber.Ctx) string {
return utils.CopyString(c.Path())
},
}))
app.Get("/", func(c fiber.Ctx) error {
c.Response().Header.Add("Cache-Time", "6000")
return c.SendString("hi")
})
```
Use `CacheInvalidator` to invalidate entries programmatically:
```go
app.Use(cache.New(cache.Config{
CacheInvalidator: func(c fiber.Ctx) bool {
return fiber.Query[bool](c, "invalidateCache")
},
}))
```
`CacheInvalidator` defines custom invalidation rules. Return `true` to invalidate the existing cache entry for the request. In the example above, setting the `invalidateCache` query parameter to `true` invalidates the entry.
Cache keys are masked in logs and error messages by default. Set `DisableValueRedaction` to `true` if you explicitly need the raw key for debugging.
## Config
| Property | Type | Description | Default |
| :------------------- | :--------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------- |
| Next | `func(fiber.Ctx) bool` | Next defines a function that is executed before creating the cache entry and can be used to execute the request without cache creation. If an entry already exists, it will be used. If you want to completely bypass the cache functionality in certain cases, you should use the [skip middleware](skip.md). | `nil` |
| Expiration | `time.Duration` | Expiration is the time that a cached response will live. | `5 * time.Minute` |
| CacheHeader | `string` | CacheHeader is the response header that indicates the cache status, with the possible values `hit`, `miss`, or `unreachable`. | `X-Cache` |
| DisableCacheControl | `bool` | DisableCacheControl omits the `Cache-Control` header when set to `true`. | `false` |
| CacheInvalidator | `func(fiber.Ctx) bool` | CacheInvalidator defines a function that is executed before checking the cache entry. It can be used to invalidate the existing cache manually by returning true. | `nil` |
| DisableValueRedaction | `bool` | Turns off cache key redaction in logs and error messages when set to `true`. | `false` |
| KeyGenerator | `func(fiber.Ctx) string` | KeyGenerator allows you to generate custom keys. The HTTP method is appended automatically. | `func(c fiber.Ctx) string { return utils.CopyString(c.Path()) }` |
| ExpirationGenerator | `func(fiber.Ctx, *cache.Config) time.Duration` | ExpirationGenerator allows you to compute a custom expiration duration based on the request. | `nil` |
| Storage | `fiber.Storage` | Storage is used to store the state of the middleware. | In-memory store |
| StoreResponseHeaders | `bool` | StoreResponseHeaders allows you to store additional response headers set by subsequent middleware and the handler. | `false` |
| MaxBytes | `uint` | MaxBytes is the maximum number of bytes of response bodies simultaneously stored in cache. | `1 * 1024 * 1024` (~1 MB) |
| Methods | `[]string` | Methods specifies the HTTP methods to cache. | `[]string{fiber.MethodGet, fiber.MethodHead}` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
Expiration: 5 * time.Minute,
CacheHeader: "X-Cache",
DisableCacheControl: false,
CacheInvalidator: nil,
DisableValueRedaction: false,
KeyGenerator: func(c fiber.Ctx) string {
return utils.CopyString(c.Path())
},
ExpirationGenerator: nil,
StoreResponseHeaders: false,
Storage: nil,
MaxBytes: 1 * 1024 * 1024,
Methods: []string{fiber.MethodGet, fiber.MethodHead},
}
```
================================================
FILE: docs/middleware/compress.md
================================================
---
id: compress
---
# Compress
Compression middleware for [Fiber](https://github.com/gofiber/fiber) that automatically compresses responses with `gzip`, `deflate`, `brotli`, or `zstd` based on the client's [Accept-Encoding](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding) header.
:::note
Bodies smaller than 200 bytes remain uncompressed because compression would likely increase their size and waste CPU cycles. [See the fasthttp source](https://github.com/valyala/fasthttp/blob/497922a21ef4b314f393887e9c6147b8c3e3eda4/http.go#L1713-L1715).
:::
## Behavior
- Skips compression for responses that already define `Content-Encoding`, for range requests, `206` responses, status codes without bodies, or when either side sends `Cache-Control: no-transform`.
- `HEAD` requests negotiate compression so `Content-Encoding`, `Content-Length`, `ETag`, and `Vary` reflect the encoded representation, but the body is removed before sending.
- When compression runs, strong `ETag` values are recomputed from the compressed bytes; when compression is skipped, `Accept-Encoding` is still merged into `Vary` unless `Vary` is `*` or already lists `Accept-Encoding`.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/compress"
)
```
Once your Fiber app is initialized, use the middleware like this:
```go
// Initialize default config
app.Use(compress.New())
// Or extend your config for customization
app.Use(compress.New(compress.Config{
Level: compress.LevelBestSpeed, // 1
}))
// Skip middleware for specific routes
app.Use(compress.New(compress.Config{
Next: func(c fiber.Ctx) bool {
return c.Path() == "/dont_compress"
},
Level: compress.LevelBestSpeed, // 1
}))
```
## Config
| Property | Type | Description | Default |
|:-------- |:-----------------------|:------------------------------------------------------------|:-------------------|
| Next | `func(fiber.Ctx) bool` | Skips this middleware when the function returns `true`. | `nil` |
| Level | `Level` | Compression level to use. | `LevelDefault (0)` |
Possible values for the "Level" field are:
- `LevelDisabled (-1)`: Compression is disabled.
- `LevelDefault (0)`: Default compression level.
- `LevelBestSpeed (1)`: Best compression speed.
- `LevelBestCompression (2)`: Best compression.
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
Level: LevelDefault,
}
```
## Constants
```go
// Compression levels
const (
LevelDisabled = -1
LevelDefault = 0
LevelBestSpeed = 1
LevelBestCompression = 2
)
```
================================================
FILE: docs/middleware/cors.md
================================================
---
id: cors
---
# CORS
CORS (Cross-Origin Resource Sharing) middleware for [Fiber](https://github.com/gofiber/fiber) lets servers control who can access resources and how. It isn't a security feature; it merely relaxes the browser's same-origin policy so cross-origin requests can succeed. Learn more on [MDN](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS).
It adds CORS headers to responses, listing allowed origins, methods, and headers, and handles preflight checks.
Use the `AllowOrigins` option to define which origins may send cross-origin requests. It accepts single origins, lists, subdomain patterns, wildcards, and supports dynamic validation with `AllowOriginsFunc`.
The middleware normalizes `AllowOrigins`, verifies HTTP/HTTPS schemes, and strips trailing slashes. Invalid origins cause a panic. Panic messages and logs redact misconfigured origins by default; set `DisableValueRedaction` to `true` if you need the raw value for troubleshooting.
Avoid [common pitfalls](#common-pitfalls) such as using wildcard origins with credentials, overly permissive origin lists, or skipping validation with `AllowOriginsFunc`, as misconfiguration can create security risks.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/cors"
)
```
Once your Fiber app is initialized, apply the middleware in one of the following ways:
### Basic usage
To use the default configuration, simply use `cors.New()`. This allows the wildcard origin `*`, the default methods (`GET`, `POST`, `HEAD`, `PUT`, `DELETE`, `PATCH`), no credentials, and no explicitly configured allowed or exposed headers.
```go
app.Use(cors.New())
```
### Custom configuration (specific origins, headers, etc.)
```go
// Initialize default config
app.Use(cors.New())
// Or extend your config for customization
app.Use(cors.New(cors.Config{
AllowOrigins: []string{"https://gofiber.io", "https://gofiber.net"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept"},
}))
```
### Dynamic origin validation
You can use `AllowOriginsFunc` to programmatically determine whether to allow a request based on its origin. This is useful when you need to validate origins against a database or other dynamic sources. The function should return `true` if the origin is allowed, and `false` otherwise.
Be sure to review the [security considerations](#security-considerations) when using `AllowOriginsFunc`.
:::caution
Never allow `AllowOriginsFunc` to return `true` for all origins. This is particularly crucial when `AllowCredentials` is set to `true`. Doing so can bypass the restriction of using a wildcard origin with credentials, exposing your application to serious security threats.
If you need to allow wildcard origins, use `AllowOrigins` with a wildcard `"*"` instead of `AllowOriginsFunc`.
:::
```go
// dbCheckOrigin checks if the origin is in the list of allowed origins in the database.
func dbCheckOrigin(db *sql.DB, origin string) bool {
// Placeholder query - adjust according to your database schema and query needs
query := "SELECT COUNT(*) FROM allowed_origins WHERE origin = $1"
var count int
err := db.QueryRow(query, origin).Scan(&count)
if err != nil {
// Handle error (e.g., log it); for simplicity, we return false here
return false
}
return count > 0
}
// ...
app.Use(cors.New(cors.Config{
AllowOriginsFunc: func(origin string) bool {
return dbCheckOrigin(db, origin)
},
}))
```
### Prohibited usage
The following example is prohibited because it can expose your application to security risks. It includes the wildcard `"*"` in `AllowOrigins` while setting `AllowCredentials` to `true`.
```go
app.Use(cors.New(cors.Config{
AllowOrigins: []string{"*"},
AllowCredentials: true,
}))
```
This will result in the following panic:
```text
panic: [CORS] Configuration error: When 'AllowCredentials' is set to true, 'AllowOrigins' cannot contain a wildcard origin '*'. Please specify allowed origins explicitly or adjust 'AllowCredentials' setting.
```
## Config
| Property | Type | Description | Default |
|:---------------------|:----------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------|
| AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. When used as part of a response to a preflight request, this indicates whether or not the actual request can be made using credentials. Note: If true, AllowOrigins cannot be set to a wildcard (`"*"`) to prevent security vulnerabilities. | `false` |
| AllowHeaders | `[]string` | AllowHeaders defines a list of request headers that can be used when making the actual request. This is in response to a preflight request. | `[]` |
| AllowMethods | `[]string` | AllowMethods defines a list of methods allowed when accessing the resource. This is used in response to a preflight request. | `["GET", "POST", "HEAD", "PUT", "DELETE", "PATCH"]` |
| AllowOrigins | `[]string` | AllowOrigins defines a list of origins that may access the resource. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. If the special wildcard `"*"` is present in the list, all origins will be allowed. | `["*"]` |
| AllowOriginsFunc | `func(origin string) bool` | `AllowOriginsFunc` is a function that dynamically determines whether to allow a request based on its origin. If this function returns `true`, the 'Access-Control-Allow-Origin' response header will be set to the request's 'origin' header. This function is only used if the request's origin doesn't match any origin in `AllowOrigins`. | `nil` |
| AllowPrivateNetwork | `bool` | Indicates whether the `Access-Control-Allow-Private-Network` response header should be set to `true`, allowing requests from private networks. This aligns with modern security practices for web applications interacting with private networks. | `false` |
| DisableValueRedaction | `bool` | Disables redaction of misconfigured origins and settings in panics and logs. | `false` |
| ExposeHeaders | `[]string` | ExposeHeaders defines an allowlist of headers that clients are allowed to access. | `[]` |
| MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If MaxAge is 0, the `Access-Control-Max-Age` header is not added and the browser uses its default of 5 seconds. To disable caching completely, pass a negative MaxAge value; this sets the `Access-Control-Max-Age` header to 0. | `0` |
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
:::note
If AllowOrigins is a zero value `[]string{}`, and AllowOriginsFunc is provided, the middleware will not default to allowing all origins with the wildcard value "*". Instead, it will rely on the AllowOriginsFunc to dynamically determine whether to allow a request based on its origin. This provides more flexibility and control over which origins are allowed.
:::
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
AllowOriginsFunc: nil,
AllowOrigins: []string{"*"},
DisableValueRedaction: false,
AllowMethods: []string{
fiber.MethodGet,
fiber.MethodPost,
fiber.MethodHead,
fiber.MethodPut,
fiber.MethodDelete,
fiber.MethodPatch,
},
AllowHeaders: []string{},
AllowCredentials: false,
ExposeHeaders: []string{},
MaxAge: 0,
AllowPrivateNetwork: false,
}
```
## Subdomain Matching
The `AllowOrigins` configuration supports matching subdomains at any level. This means you can use a value like `"https://*.example.com"` to allow any subdomain of `example.com` to submit requests, including multiple subdomain levels such as `"https://sub.sub.example.com"`.
### Example
If you want to allow CORS requests from any subdomain of `example.com`, including nested subdomains, you can configure the `AllowOrigins` like so:
```go
app.Use(cors.New(cors.Config{
AllowOrigins: []string{"https://*.example.com"},
}))
```
## How It Works
The CORS middleware works by adding the necessary CORS headers to responses from your Fiber application. These headers tell browsers what origins, methods, and headers are allowed for cross-origin requests.
When a request arrives, the middleware first checks whether it is a preflight request—a CORS mechanism that determines if the actual request is safe to send. Preflight requests are HTTP OPTIONS requests with specific CORS headers. If the request is preflight, the middleware responds with the appropriate CORS headers and ends the request.
:::note
Preflight requests are typically sent by browsers before making actual cross-origin requests, especially for methods other than GET, HEAD, or POST, or when custom headers are used.
A preflight request is an HTTP OPTIONS request that includes the `Origin`, `Access-Control-Request-Method`, and optionally `Access-Control-Request-Headers` headers. The browser sends this request to check if the server allows the actual request method and headers.
:::
If the request is not preflight, the middleware adds the CORS headers to the response and passes the request to the next handler. The actual CORS headers added depend on the configuration of the middleware.
The `AllowOrigins` option controls which origins can make cross-origin requests. The middleware handles different `AllowOrigins` configurations as follows:
- **Single origin:** If `AllowOrigins` is set to a single origin like `"http://www.example.com"`, and that origin matches the origin of the incoming request, the middleware adds the header `Access-Control-Allow-Origin: http://www.example.com` to the response.
- **Multiple origins:** If `AllowOrigins` is set to multiple origins like `[]string{"https://example.com", "https://www.example.com"}`, the middleware picks the origin that matches the origin of the incoming request.
- **Subdomain matching:** If `AllowOrigins` includes `"https://*.example.com"`, a subdomain like `https://sub.example.com` will be matched and `"https://sub.example.com"` will be the header. This will also match `https://sub.sub.example.com` and so on, but not `https://example.com`.
- **Wildcard origin:** If `AllowOrigins` contains `"*"` (for example, `[]string{"*"}`), the middleware uses that and adds the header `Access-Control-Allow-Origin: *` to the response.
In all cases above, except the **Wildcard origin**, the middleware will either add the `Access-Control-Allow-Origin` header to the response matching the origin of the incoming request, or it will not add the header at all if the origin is not allowed.
- **Programmatic origin validation:** The middleware also handles the `AllowOriginsFunc` option, which allows you to programmatically determine if an origin is allowed. If `AllowOriginsFunc` returns `true` for an origin, the middleware sets the `Access-Control-Allow-Origin` header to that origin.
- **Null origin handling:** The middleware accepts the special literal value `"null"` as a valid origin. According to the [CORS specification](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS#origin), browsers send `"null"` as the origin for certain privacy-sensitive contexts, such as:
- Requests from sandboxed iframes
- Requests from `file://` URLs
- Requests from `data:` URLs
- Cross-origin redirects
When using `AllowOriginsFunc`, if the function returns `true` for the literal string `"null"`, the middleware will set `Access-Control-Allow-Origin: null` in the response. The `"null"` origin is case-sensitive and must be lowercase.
The `AllowMethods` option controls which HTTP methods are allowed. For example, if `AllowMethods` is set to `[]string{"GET", "POST"}`, the middleware adds the header `Access-Control-Allow-Methods: GET, POST` to the response.
The `AllowHeaders` option specifies which headers are allowed in the actual request. The middleware sets the Access-Control-Allow-Headers response header to the value of `AllowHeaders`. This informs the client which headers it can use in the actual request.
The `AllowCredentials` option indicates whether the response to the request can be exposed when the credentials flag is true. If `AllowCredentials` is set to `true`, the middleware adds the header `Access-Control-Allow-Credentials: true` to the response. To prevent security vulnerabilities, `AllowCredentials` cannot be set to `true` if `AllowOrigins` is set to a wildcard (`*`).
The `ExposeHeaders` option defines an allowlist of headers that clients are allowed to access. If `ExposeHeaders` is set to `[]string{"X-Custom-Header"}`, the middleware adds the header `Access-Control-Expose-Headers: X-Custom-Header` to the response.
The `MaxAge` option indicates how long the results of a preflight request can be cached. If `MaxAge` is set to `3600`, the middleware adds the header `Access-Control-Max-Age: 3600` to the response.
The `Vary` header helps caches store the correct response. For simple requests the middleware sets `Vary: Origin` unless all origins are allowed. Preflight responses add `Vary: Origin, Access-Control-Request-Method, Access-Control-Request-Headers` (and `Access-Control-Request-Private-Network` when enabled and requested). This ensures caches know when to reuse a response and when to revalidate with the server.
## Infrastructure Considerations
When deploying Fiber applications behind infrastructure components like CDNs, API gateways, load balancers, or reverse proxies, you have two main options for handling CORS:
### Option 1: Use Infrastructure-Level CORS (Recommended)
**For most production deployments, it is often preferable to handle CORS at the infrastructure level** rather than in your Fiber application. This approach offers several advantages:
- **Better Performance**: CORS headers are added at the edge, closer to the client
- **Reduced Server Load**: Preflight requests are handled without reaching your application
- **Centralized Configuration**: Manage CORS policies alongside other infrastructure settings
- **Built-in Caching**: Infrastructure providers optimize CORS response caching
**Common infrastructure CORS solutions:**
- **CDNs**: CloudFront, CloudFlare, Azure CDN - handle CORS at edge locations
- **API Gateways**: AWS API Gateway, Google Cloud API Gateway - centralized CORS management
- **Load Balancers**: Application Load Balancers with CORS rules
- **Reverse Proxies**: Nginx, Apache with CORS modules
If using infrastructure-level CORS, **disable Fiber's CORS middleware** to avoid conflicts:
```go
// Don't use both - choose one approach
// app.Use(cors.New()) // Remove this line when using infrastructure CORS
```
### Option 2: Application-Level CORS (Fiber Middleware)
Use Fiber's CORS middleware when you need:
- **Dynamic origin validation** based on application logic
- **Fine-grained control** over CORS policies per route
- **Integration with application state** (database-driven origins, etc.)
- **Development environments** where infrastructure CORS isn't available
If choosing this approach, ensure that **all CORS headers reach your Fiber application unchanged**.
### Required Headers for CORS Preflight Requests
For CORS preflight requests to work correctly, these headers **must not be stripped or modified by caching layers**:
- `Origin` - Required to identify the requesting origin
- `Access-Control-Request-Method` - Required to identify the HTTP method for the actual request
- `Access-Control-Request-Headers` - Optional, contains custom headers the actual request will use
- `Access-Control-Request-Private-Network` - Optional, for private network access requests
:::warning Critical Preflight Requirement
If the `Access-Control-Request-Method` header is missing from an OPTIONS request, Fiber will not recognize it as a CORS preflight request. Instead, it will be treated as a regular OPTIONS request, which typically returns `405 Method Not Allowed` since most applications don't define explicit OPTIONS handlers.
:::
### CORS Response Headers (Set by Fiber)
The middleware sets these response headers based on your configuration:
**For all CORS requests:**
- `Access-Control-Allow-Origin` - Set to the allowed origin or "*"
- `Access-Control-Allow-Credentials` - Set to "true" when `AllowCredentials: true`
- `Access-Control-Expose-Headers` - Lists headers the client can access
- `Vary` - Set to "Origin" (unless wildcard origins are used)
**For preflight responses only:**
- `Access-Control-Allow-Methods` - Lists allowed HTTP methods
- `Access-Control-Allow-Headers` - Lists allowed request headers (or echoes the request)
- `Access-Control-Max-Age` - Cache duration for preflight results (sent when MaxAge > 0; a negative MaxAge sends 0)
- `Access-Control-Allow-Private-Network` - Set to "true" when private network access is allowed
- `Vary` - Set to "Access-Control-Request-Method, Access-Control-Request-Headers, Origin"
### Common Infrastructure Issues
**CDNs (CloudFront, CloudFlare, etc.)**:
- Configure cache policies to forward all CORS headers
- Ensure OPTIONS requests are not cached inappropriately or cache them correctly with proper Vary headers
- Don't strip or modify CORS request headers
**API Gateways**:
- Choose either gateway-level CORS OR application-level CORS, not both
- If using gateway CORS, disable Fiber's CORS middleware
- If forwarding to Fiber, ensure all headers pass through unchanged
**Load Balancers/Reverse Proxies**:
- Preserve all HTTP headers, especially CORS-related ones
- Don't modify or strip `Origin`, `Access-Control-Request-*` headers
**WAFs/Security Services**:
- Allowlist CORS headers in security rules
- Ensure OPTIONS requests with CORS headers aren't blocked
### Debugging CORS Issues
Add this middleware **before** your CORS configuration to debug what headers Fiber receives:
```go
// Debug middleware to log CORS preflight requests
// Only use in development or testing environments
app.Use(func(c fiber.Ctx) error {
if c.Method() == "OPTIONS" {
fmt.Printf("OPTIONS %s\n", c.Path())
fmt.Printf(" Origin: %s\n", c.Get("Origin"))
fmt.Printf(" Access-Control-Request-Method: %s\n", c.Get("Access-Control-Request-Method"))
fmt.Printf(" Access-Control-Request-Headers: %s\n", c.Get("Access-Control-Request-Headers"))
}
return c.Next()
})
app.Use(cors.New(cors.Config{
AllowOrigins: []string{"https://yourdomain.com"},
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
}))
```
Test CORS preflight directly with curl:
```bash
# Test preflight request
curl -X OPTIONS https://your-app.com/api/test \
-H "Origin: https://yourdomain.com" \
-H "Access-Control-Request-Method: POST" \
-H "Access-Control-Request-Headers: Content-Type" \
-v
# Test simple CORS request
curl -X GET https://your-app.com/api/test \
-H "Origin: https://yourdomain.com" \
-v
```
### Caching Considerations
The middleware sets appropriate `Vary` headers to ensure proper caching:
- **Non-wildcard origins**: `Vary: Origin` is set to cache responses per origin
- **Preflight requests**: `Vary: Access-Control-Request-Method, Access-Control-Request-Headers, Origin`
- **OPTIONS without preflight headers**: `Vary: Origin` to avoid cache poisoning
Ensure your infrastructure respects these `Vary` headers for correct caching behavior.
### Choosing the Right Approach
| Scenario | Recommended Approach |
|----------|---------------------|
| Production with CDN/API Gateway | Infrastructure-level CORS |
| Dynamic origin validation needed | Application-level CORS |
| Microservices with different CORS policies | Application-level CORS |
| Simple static origins | Infrastructure-level CORS |
| Development/testing | Application-level CORS |
| High traffic applications | Infrastructure-level CORS |
:::tip Infrastructure CORS Configuration
Most cloud providers offer comprehensive CORS documentation:
- [AWS CloudFront CORS](https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/header-caching.html#header-caching-web-cors)
- [Google Cloud CORS](https://cloud.google.com/storage/docs/cross-origin)
- [Azure CDN CORS](https://docs.microsoft.com/en-us/azure/cdn/cdn-cors)
- [Cloudflare CORS](https://developers.cloudflare.com/cache/cache-security/cors/)
Configure CORS at the infrastructure level when possible for optimal performance and reduced complexity.
:::
## Security Considerations
When configuring CORS, misconfiguration can potentially expose your application to various security risks. Here are some secure configurations and common pitfalls to avoid:
### Secure Configurations
- **Specify Allowed Origins**: Instead of using a wildcard (`"*"`), specify the exact domains allowed to make requests. For example, `AllowOrigins: []string{"https://www.example.com", "https://api.example.com"}` ensures only these domains can make cross-origin requests to your application.
- **Use Credentials Carefully**: If your application needs to support credentials in cross-origin requests, ensure `AllowCredentials` is set to `true` and specify exact origins in `AllowOrigins`. Do not use a wildcard origin in this case.
- **Limit Exposed Headers**: Only allowlist headers that are necessary for the client-side application by setting `ExposeHeaders` appropriately. This minimizes the risk of exposing sensitive information.
### Common Pitfalls
- **Wildcard Origin with Credentials**: Setting `AllowOrigins` to `"*"` (a wildcard) and `AllowCredentials` to `true` is a common misconfiguration. This combination is prohibited because it can expose your application to security risks.
- **Overly Permissive Origins**: Specifying too many origins or using overly broad patterns (e.g., `https://*.example.com`) can inadvertently allow malicious sites to interact with your application. Be as specific as possible with allowed origins.
- **Inadequate `AllowOriginsFunc` Validation**: When using `AllowOriginsFunc` for dynamic origin validation, ensure the function includes robust checks to prevent unauthorized origins from being accepted. Overly permissive validation can lead to security vulnerabilities. Never allow `AllowOriginsFunc` to return `true` for all origins. This is particularly crucial when `AllowCredentials` is set to `true`. Doing so can bypass the restriction of using a wildcard origin with credentials, exposing your application to serious security threats. If you need to allow wildcard origins, use `AllowOrigins` with a wildcard `"*"` instead of `AllowOriginsFunc`.
Remember, the key to secure CORS configuration is specificity and caution. By carefully selecting which origins, methods, and headers are allowed, you can help protect your application from cross-origin attacks.
================================================
FILE: docs/middleware/csrf.md
================================================
---
id: csrf
---
# CSRF
The CSRF middleware protects against [Cross-Site Request Forgery](https://en.wikipedia.org/wiki/Cross-site_request_forgery) attacks by validating tokens on unsafe HTTP methods such as POST, PUT, and DELETE. It responds with 403 Forbidden when validation fails.
## Table of Contents
- [Quick Start](#quick-start)
- [Best Practices & Production Requirements](#best-practices--production-requirements)
- [Configuration by Application Type](#configuration-by-application-type)
- [Recipes for Common Use Cases](#recipes-for-common-use-cases)
- [Using CSRF Tokens](#using-csrf-tokens)
- [Security Model](#security-model)
- [Token Extractors](#token-extractors)
- [Advanced Configuration](#advanced-configuration)
- [API Reference](#api-reference)
- [Config Properties](#config-properties)
- [Error Types](#error-types)
- [Constants](#constants)
## Quick Start
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/fiber/v3/middleware/csrf"
)
// Default config (development only)
app.Use(csrf.New())
// Production config
app.Use(csrf.New(csrf.Config{
CookieName: "__Host-csrf_",
CookieSecure: true,
CookieHTTPOnly: true, // false for SPAs
CookieSameSite: "Lax",
CookieSessionOnly: true,
Extractor: extractors.FromHeader("X-Csrf-Token"),
Session: sessionStore,
// Redaction is enabled by default. Set DisableValueRedaction when you must expose tokens or storage keys in diagnostics.
// DisableValueRedaction: true,
}))
```
## Best Practices & Production Requirements
:::danger Production Requirements
- `CookieSecure: true` (HTTPS only)
- `CookieSameSite: "Lax"` or `"Strict"`
- Use `Session` store for better security
:::
1. **Always use HTTPS** in production
2. **Use sessions** for authenticated applications
3. **Set `CookieSecure: true`** and appropriate SameSite values
4. **Implement XSS protection** alongside CSRF
5. **Regenerate tokens** after auth changes
6. **Use `__Host-` cookie prefix** when possible
:::warning BREACH Protection
To mitigate BREACH attacks, ensure your pages are served over HTTPS, disable HTTP compression, and implement rate limiting for requests. The CSRF token is sent as a header on every request, so if you include the token in a page that is vulnerable to BREACH, an attacker may be able to extract the token.
:::
## Configuration by Application Type
### Server-Side Rendered Apps
```go
app.Use(csrf.New(csrf.Config{
CookieName: "__Host-csrf_",
CookieSecure: true,
CookieHTTPOnly: true, // Secure - blocks JavaScript
CookieSameSite: "Lax",
CookieSessionOnly: true,
Extractor: extractors.FromForm("_csrf"),
Session: sessionStore,
}))
```
### Single Page Applications (SPAs)
```go
app.Use(csrf.New(csrf.Config{
CookieName: "__Host-csrf_",
CookieSecure: true,
CookieHTTPOnly: false, // Required for JavaScript access to tokens
CookieSameSite: "Lax",
CookieSessionOnly: true,
Extractor: extractors.FromHeader("X-Csrf-Token"),
Session: sessionStore,
}))
```
:::warning SPA Security Trade-off
SPAs require `CookieHTTPOnly: false` to access tokens via JavaScript. This slightly increases XSS risk but is necessary for SPA functionality.
:::
## Recipes for Common Use Cases
- **Without Sessions**: [CSRF Recipe](https://github.com/gofiber/recipes/tree/master/csrf) - Simple Double Submit Cookie pattern
- **With Sessions**: [CSRF with Session Recipe](https://github.com/gofiber/recipes/tree/master/csrf-with-session) - More secure Synchronizer Token pattern
## Using CSRF Tokens
### Server-Side Forms
```go
func formHandler(c fiber.Ctx) error {
token := csrf.TokenFromContext(c)
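return c.SendString(fmt.Sprintf(`
<form method="POST" action="/submit">
<input type="hidden" name="_csrf" value="%s">
<input type="text" name="data">
<button type="submit">Submit</button>
</form>
`, token))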
}
```
### Single Page Applications
```go
func apiHandler(c fiber.Ctx) error {
token := csrf.TokenFromContext(c)
return c.JSON(fiber.Map{
"csrf_token": token,
"data": "your data",
})
}
```
```javascript
// Get CSRF token from cookie
function getCsrfToken() {
const value = `; ${document.cookie}`;
const parts = value.split(`; __Host-csrf_=`);
if (parts.length === 2) return parts.pop().split(';').shift();
}
// Use with fetch API
async function makeRequest(url, data) {
const csrfToken = getCsrfToken();
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Csrf-Token': csrfToken
},
body: JSON.stringify(data)
});
if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
}
return response.json();
}
```
## Security Model
The middleware employs a robust, defense-in-depth strategy to protect against CSRF attacks. The primary defense is token-based validation, which operates in one of two modes depending on your configuration. This is supplemented by a mandatory secondary check on the request's origin.
### Fetch Metadata Guardrails
- **Sec-Fetch-Site**: For unsafe methods, the middleware inspects the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header when present. If the header value is not one of "same-origin", "none", "same-site", or "cross-site", the request is rejected with `ErrFetchSiteInvalid`. If the header is valid or absent, the request proceeds to the standard origin and token validation checks. This provides an early check to block requests with invalid `Sec-Fetch-Site` values, while allowing legitimate same-site and cross-site requests to be validated by the existing mechanisms.
### 1. Token Validation Patterns
#### Double Submit Cookie (Default Mode)
This is the default pattern, used when a `Session` store is **not** configured. It is a "semi-stateless" approach; while it doesn't tie tokens to a specific user session, the server still maintains a record of all validly issued tokens.
- **How it Works:**
1. On a user's first visit (or a safe request like `GET`), the middleware generates a unique token.
2. This token is sent to the client in a `Set-Cookie` header.
3. The server also stores this token (in memory by default or in the configured `Storage`). It confirms the token is server-generated and still valid, but it is not tied to a specific user.
4. For subsequent unsafe requests (e.g., `POST`, `PUT`), the client must read the token from the cookie and echo it in a different location, such as the `X-Csrf-Token` header.
- **Validation:** The middleware validates three things: that the token from the header/form **exactly matches** the token from the cookie, that the token **exists** in the server-side storage, and that it **has not expired**.
- **Why it is secure:** Attackers on a malicious domain cannot read the victim's cookie to forge a matching header. They also cannot invent a token because it wouldn't exist in the server's storage registry.
#### Synchronizer Token (Session-Based Mode)
This is a more secure, stateful pattern that is **automatically enabled** when you provide a `Session` store in the configuration.
- **How it Works:**
1. A unique token is generated and stored directly within the user's session data on the server.
2. The token is also sent to the client as a cookie.
3. For unsafe requests, the client sends the token back in a header or form field.
- **Validation:** The middleware performs a multi-step validation:
1. It first performs the standard **Double Submit Cookie check**: the token from the header/form must exactly match the token from the cookie. This is a fast and efficient first line of defense, and there is little benefit to skipping it.
2. It then validates that this token exists and is valid within the user's **server-side session**. This is the authoritative check that ties the token to the authenticated user.
- **Why it is more secure:** Tying the token to the server-side session provides the strongest CSRF protection, as the token is then guaranteed to have been generated for the specific user. While browsers automatically send the required cookie, custom API clients must remember to include the cookie with their requests for validation to succeed.
```go
// Enable the more secure Synchronizer Token pattern
app.Use(csrf.New(csrf.Config{
Session: sessionStore, // Providing a session store activates this mode
}))
```
### 2. Origin & Referer Validation
As a crucial second layer of defense, the middleware also validates the `Origin` header of unsafe requests. When the `Origin` header is absent on an HTTPS connection, it falls back to validating the `Referer` header instead.
- The `Origin` header (or, in the fallback case, the `Referer` header) **must** match the application's `Host` header or be explicitly allowed in the `TrustedOrigins` list.
- This check is performed *in addition* to token validation and provides strong protection because these headers are reliably set by browsers and cannot be programmatically controlled by an attacker from a malicious site.
## Token Extractors
This middleware uses the shared `extractors` package for token extraction. For full details on extractor types, chaining, security, and advanced usage, see the [Extractors Guide](../guide/extractors).
**Extractor Source Constants:**
Extractor source constants (such as `SourceHeader`, `SourceForm`, etc.) are defined in the shared extractors package, not in the CSRF middleware itself. Refer to the Extractors Guide for their definitions and usage.
### CSRF-Specific Extractor Notes
For CSRF protection, prefer secure extraction methods:
- **Headers** (`extractors.FromHeader("X-Csrf-Token")`) – Most secure, not logged in URLs
- **Form data** (`extractors.FromForm("_csrf")`) – Secure for form submissions
- **Avoid URL parameters** – Query/param extractors expose tokens in logs and browser history
:::note What about cookies?
**Cookies are generally not a secure source for CSRF tokens.** The middleware will panic if you configure an extractor that reads from cookies with the same name as your CSRF cookie. This is because reading the CSRF token from a cookie with the same name as the CSRF cookie defeats CSRF protection entirely, as the extracted token will always match the cookie value, allowing any CSRF attack to succeed.
**Advanced usage:**
In rare cases, you may securely extract a CSRF token from a cookie if:
- You read from a different cookie (not the CSRF cookie itself)
- You use multiple cookies for custom validation
- You implement custom logic across different cookie sources
If you do this, set the extractor’s `Source` to `SourceCookie` so the middleware can verify that the cookie name differs from your CSRF cookie name. The middleware panics if the two names match.
**Warning:**
Cookie-based extraction is strongly discouraged, as it is easy to misconfigure and creates security risks. Prefer extracting tokens from headers or form fields for robust CSRF protection. See the [Extractors Guide](../guide/extractors#security-considerations) for more details.
:::
### Route-Specific Configuration
You can configure different extraction methods for different routes:
```go
// API routes - header extraction for AJAX/fetch requests
api := app.Group("/api")
api.Use(csrf.New(csrf.Config{
Extractor: extractors.FromHeader("X-Csrf-Token"),
}))
// Form routes - form field extraction for traditional forms
forms := app.Group("/forms")
forms.Use(csrf.New(csrf.Config{
Extractor: extractors.FromForm("_csrf"),
}))
```
### Custom CSRF Extractors
For specialized CSRF token extraction needs, you can create custom extractors. See the [Extractors Guide](../guide/extractors#custom-extraction-logic) for advanced patterns and security notes.
:::danger Never Extract from Cookies
**NEVER create custom extractors that read from cookies using the same `CookieName` as your CSRF configuration.** This completely defeats CSRF protection by making the extracted token always match the cookie value, allowing any CSRF attack to succeed.
```go
// ❌ NEVER DO THIS - Completely defeats CSRF protection
badExtractor := extractors.Extractor{
Extract: func(c fiber.Ctx) (string, error) {
return c.Cookies("csrf_"), nil // Always passes validation!
},
Source: extractors.SourceCustom, // Source constants live in the shared extractors package
Key: "csrf_",
}
// ✅ DO THIS - Extract from different source than cookie
app.Use(csrf.New(csrf.Config{
CookieName: "csrf_",
Extractor: extractors.FromHeader("X-Csrf-Token"), // Header vs cookie comparison
}))
```
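The middleware uses the **Double Submit Cookie** pattern – it compares the extracted token against the cookie value. If you configure a cookie extractor (`Source: extractors.SourceCookie`) that reads the same cookie name, the middleware panics, because the two values would always match and provide zero CSRF protection. A custom extractor that reads the cookie directly cannot be detected this way, so never write one.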
:::
#### Bearer Token Embedding & Custom Extractors
You can create advanced extractors for use cases like JWT embedding or JSON body parsing. See the [Extractors Guide](../guide/extractors#custom-extraction-logic) for secure implementation patterns and more examples.
### Fallback Extraction
For applications that need to support both AJAX and form submissions:
```go
// Try header first (AJAX), fall back to form (traditional forms)
app.Use(csrf.New(csrf.Config{
Extractor: extractors.Chain(
extractors.FromHeader("X-Csrf-Token"),
extractors.FromForm("_csrf"),
),
}))
```
:::warning
Chaining extractors increases complexity. Use only when you need to support multiple client types. See the [Extractors Guide](../guide/extractors#chain-ordering-strategy) for details and security notes.
:::
## Advanced Configuration
### Trusted Origins
```go
app.Use(csrf.New(csrf.Config{
TrustedOrigins: []string{
"https://trusted.example.com",
"https://*.example.com", // Wildcard subdomains
},
}))
```
### Custom Error Handler
```go
app.Use(csrf.New(csrf.Config{
ErrorHandler: func(c fiber.Ctx, err error) error {
accepts := c.Accepts("html", "json")
path := c.Path()
if accepts == "json" || strings.HasPrefix(path, "/api/") {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "Forbidden",
})
}
return c.Status(fiber.StatusForbidden).Render("error", fiber.Map{
"Title": "Forbidden",
"Status": fiber.StatusForbidden,
}, "layouts/main")
},
}))
```
### Custom Storage/Database
You can use any storage from our [storage](https://github.com/gofiber/storage/) package.
```go
storage := sqlite3.New() // From github.com/gofiber/storage/sqlite3/v2
app.Use(csrf.New(csrf.Config{
Storage: storage,
}))
```
### Token Management
```go
// Delete token (e.g., on logout)
handler := csrf.HandlerFromContext(c)
if handler != nil {
if err := handler.DeleteToken(c); err != nil {
// handle error, e.g. log it
}
}
// With session middleware
// Destroying the session will also remove the CSRF token if using session-based CSRF.
session.Destroy()
```
## API Reference
```go
// Create middleware
func New(config ...csrf.Config) fiber.Handler
// Get token from context
func TokenFromContext(ctx any) string
// Get handler from context
func HandlerFromContext(ctx any) *csrf.Handler
// Delete token
func (h *csrf.Handler) DeleteToken(c fiber.Ctx) error
```
`TokenFromContext` and `HandlerFromContext` accept a `fiber.CustomCtx`, `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`.
## Config Properties
| Property | Type | Description | Default |
|:------------------|:-----------------------------------|:------------------------------------------------------------------------------------------------------------------------------|:-----------------------------|
| Next | `func(fiber.Ctx) bool` | Skip middleware when returns true | `nil` |
| CookieName | `string` | CSRF cookie name | `"csrf_"` |
| CookieDomain | `string` | CSRF cookie domain | `""` |
| CookiePath | `string` | CSRF cookie path | `""` |
| CookieSecure | `bool` | HTTPS only cookie (**required for production**) | `false` |
| CookieHTTPOnly | `bool` | Prevent JavaScript access (**use `false` for SPAs**) | `false` |
| CookieSameSite | `string` | SameSite attribute (**use "Lax" or "Strict"**) | `"Lax"` |
| CookieSessionOnly | `bool` | Session-only cookie (expires on browser close) | `false` |
| IdleTimeout | `time.Duration` | Token expiration time | `30 * time.Minute` |
| KeyGenerator | `func() string` | Token generation function | `utils.SecureToken` |
| ErrorHandler | `fiber.ErrorHandler` | Custom error handler | `defaultErrorHandler` |
| Extractor | `extractors.Extractor` | Token extraction method with metadata | `extractors.FromHeader("X-Csrf-Token")` |
| DisableValueRedaction | `bool` | Disables redaction of tokens and storage keys in logs and error messages. | `false` |
| Session | `*session.Store` | Session store (**recommended for production**) | `nil` |
| Storage | `fiber.Storage` | Token storage (overridden by Session) | `nil` |
| TrustedOrigins | `[]string` | Trusted origins for cross-origin requests | `[]` |
| SingleUseToken | `bool` | Generate new token after each use | `false` |
## Error Types
```go
var (
ErrTokenNotFound = errors.New("csrf: token not found")
ErrTokenInvalid = errors.New("csrf: token invalid")
ErrRefererNotFound = errors.New("csrf: referer header missing")
ErrRefererInvalid = errors.New("csrf: referer header invalid")
ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins")
ErrOriginInvalid = errors.New("csrf: origin header invalid")
ErrOriginNoMatch = errors.New("csrf: origin does not match host or trusted origins")
)
```
## Constants
```go
const (
HeaderName = "X-Csrf-Token"
)
```
================================================
FILE: docs/middleware/earlydata.md
================================================
---
id: earlydata
---
# EarlyData
The Early Data middleware adds TLS 1.3 "0-RTT" support to [Fiber](https://github.com/gofiber/fiber). When the client and server share a PSK, TLS 1.3 lets the client send data with the first flight and skip the initial round trip.
Enable Fiber's `TrustProxy` option before using this middleware to avoid spoofed client headers.
When `TrustProxy` is disabled (the default) or the remote address is not trusted by your proxy configuration, requests carrying the `Early-Data` header are rejected with `425 Too Early` to prevent 0-RTT spoofing from direct clients.
Enabling early data in a reverse proxy (for example, `ssl_early_data on;` in nginx) makes requests replayable. Review these resources before proceeding:
- [datatracker](https://datatracker.ietf.org/doc/html/rfc8446#section-8)
- [trailofbits](https://blog.trailofbits.com/2019/03/25/what-application-developers-need-to-know-about-tls-early-data-0rtt)
By default, the middleware permits early data only for safe methods (`GET`, `HEAD`, `OPTIONS`, `TRACE`) and rejects other requests before your handler runs. Override this behavior with the `AllowEarlyData` option.
## Signatures
```go
func New(config ...Config) fiber.Handler
func IsEarly(c fiber.Ctx) bool
```
`IsEarly` returns `true` when a request used early data and the middleware allowed it to proceed.
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/earlydata"
)
```
Once your Fiber app is initialized, use the middleware like this:
```go
// Initialize default config
app.Use(earlydata.New())
// Or extend your config for customization
app.Use(earlydata.New(earlydata.Config{
Error: fiber.ErrTooEarly,
// ...
}))
```
## Config
| Property | Type | Description | Default |
|:---------------|:------------------------|:-----------|:-------------------------------------------------------|
| Next | `func(fiber.Ctx) bool` | Skip this middleware when the function returns true. | `nil` |
| IsEarlyData | `func(fiber.Ctx) bool` | Reports whether the request used early data. | Function checking if "Early-Data" header equals "1" |
| AllowEarlyData | `func(fiber.Ctx) bool` | Decides if an early-data request should be allowed. | Function rejecting on unsafe and allowing safe methods |
| Error | `error` | Returned when an early-data request is rejected. | `fiber.ErrTooEarly` |
## Default Config
```go
var ConfigDefault = Config{
IsEarlyData: func(c fiber.Ctx) bool {
return c.Get(DefaultHeaderName) == DefaultHeaderTrueValue
},
AllowEarlyData: func(c fiber.Ctx) bool {
return fiber.IsMethodSafe(c.Method())
},
Error: fiber.ErrTooEarly,
}
```
## Constants
```go
const (
DefaultHeaderName = "Early-Data"
DefaultHeaderTrueValue = "1"
)
```
================================================
FILE: docs/middleware/encryptcookie.md
================================================
---
id: encryptcookie
---
# Encrypt Cookie
The Encrypt Cookie middleware for [Fiber](https://github.com/gofiber/fiber) encrypts cookie values for secure storage.
:::note
This middleware encrypts cookie values but not cookie names.
:::
## Signatures
```go
// Initializes the middleware
func New(config ...Config) fiber.Handler
// GenerateKey returns a base64-encoded random key of 16, 24, or 32 bytes.
// The key length determines the AES-GCM variant used:
// 16 bytes for AES-128, 24 bytes for AES-192, and 32 bytes for AES-256.
func GenerateKey(length int) string
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/encryptcookie"
)
```
Once your Fiber app is initialized, register the middleware:
```go
// Provide a minimal configuration
app.Use(encryptcookie.New(encryptcookie.Config{
Key: "secret-32-character-string",
}))
// Retrieve the encrypted cookie value
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("value=" + c.Cookies("test"))
})
// Create an encrypted cookie
app.Post("/", func(c fiber.Ctx) error {
c.Cookie(&fiber.Cookie{
Name: "test",
Value: "SomeThing",
})
return nil
})
```
:::note
Use a base64-encoded key that decodes to 16, 24, or 32 bytes to select AES-128-GCM, AES-192-GCM, or AES-256-GCM. Generate a stable key with `openssl rand -base64 32` or `encryptcookie.GenerateKey(32)` and store it securely. Generating a new key on each startup renders existing cookies unreadable.
:::
## Config
| Property | Type | Description | Default |
|:----------|:----------------------------------------------------|:------------------------------------------------------------------------------------------------------|:-----------------------------|
| Next | `func(fiber.Ctx) bool` | A function to skip this middleware when it returns true. | `nil` |
| Except | `[]string` | Array of cookie keys that should not be encrypted. | `[]` |
| Key | `string` | A base64-encoded key used to encrypt and decrypt cookies. Required. The decoded key must be 16, 24, or 32 bytes long. | (No default, required field) |
| Encryptor | `func(name, decryptedString, key string) (string, error)` | A custom function to encrypt cookies. | `EncryptCookie` |
| Decryptor | `func(name, encryptedString, key string) (string, error)` | A custom function to decrypt cookies. | `DecryptCookie` |
### Encryptor and Decryptor parameters
Custom encryptor and decryptor functions receive three arguments:
- `name`: The cookie name. The default helpers bind this value as additional authenticated data (AAD) so encrypted values can only be decrypted for the same cookie.
- `decryptedString` / `encryptedString`: The cookie payload. The encryptor receives the plaintext value as `decryptedString` and returns ciphertext, while the decryptor receives ciphertext as `encryptedString` and must return the plaintext value.
- `key`: The base64-encoded key pulled from the middleware configuration. Use it to derive or validate any encryption keys your implementation requires.
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
Except: []string{},
Key: "",
Encryptor: EncryptCookie,
Decryptor: DecryptCookie,
}
```
## Use with Other Middleware That Reads or Modifies Cookies
Place `encryptcookie` before middleware that reads or writes cookies. If you use the CSRF middleware, register `encryptcookie` first so the CSRF middleware receives the decrypted token.
Exclude cookies from encryption by listing them in `Except`. If a frontend framework such as Angular reads the CSRF token from a cookie, add that name to the `Except` array:
```go
app.Use(encryptcookie.New(encryptcookie.Config{
Key: "secret-thirty-2-character-string",
Except: []string{csrf.ConfigDefault.CookieName}, // exclude CSRF cookie
}))
app.Use(csrf.New(csrf.Config{
Extractor: csrf.FromHeader(csrf.HeaderName),
CookieSameSite: "Lax",
CookieSecure: true,
CookieHTTPOnly: false,
}))
```
## Encryption Algorithms
The default Encryptor and Decryptor functions use AES-GCM, and the length of the decoded key selects the cipher strength: a 32-byte key (for example, one from `encryptcookie.GenerateKey(32)`) yields `AES-256-GCM`. If you need to use `AES-128` or `AES-192` instead, change the length passed to `encryptcookie.GenerateKey(length)` or provide a key of one of the following lengths:
- AES-128 requires a 16-byte key.
- AES-192 requires a 24-byte key.
- AES-256 requires a 32-byte key.
For example, to generate a key for AES-128:
```go
key := encryptcookie.GenerateKey(16)
```
And for AES-192:
```go
key := encryptcookie.GenerateKey(24)
```
================================================
FILE: docs/middleware/envvar.md
================================================
---
id: envvar
---
# EnvVar
EnvVar middleware for [Fiber](https://github.com/gofiber/fiber) exposes environment variables with configurable options.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/envvar"
)
```
Once your Fiber app is initialized, configure the middleware as shown:
```go
// Initialize default config (exports no variables)
app.Use("/expose/envvars", envvar.New())
// Or extend your config for customization
app.Use("/expose/envvars", envvar.New(
envvar.Config{
ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"},
}),
)
```
:::note
Mount the middleware on a path; it cannot be used without one.
:::
## Response
Sample response:
```json
{
"vars": {
"someEnvVariable": "someValue",
"anotherEnvVariable": "anotherValue"
}
}
```
## Config
| Property | Type | Description | Default |
|:------------|:--------------------|:-----------------------------------------------------------------------------|:--------|
| ExportVars | `map[string]string` | ExportVars lists the environment variables to expose. | `nil` |
## Default Config
```go
Config{}
// Exports no environment variables
```
================================================
FILE: docs/middleware/etag.md
================================================
---
id: etag
---
# ETag
ETag middleware for [Fiber](https://github.com/gofiber/fiber) that helps caches validate responses and saves bandwidth by avoiding full retransmits when content is unchanged.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/etag"
)
```
Once your Fiber app is initialized, use the middleware like this:
```go
// Initialize default config
app.Use(etag.New())
// GET / -> ETag: "13-1831710635"
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
// Or extend your config for customization
app.Use(etag.New(etag.Config{
Weak: true,
}))
// GET / -> ETag: W/"13-1831710635"
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
```
Entity tags in requests must be quoted per RFC 9110. For example:
```text
If-None-Match: "example-etag"
```
## Config
| Property | Type | Description | Default |
|:---------|:------------------------|:-------------------------------------------------------------------------------------------------------------------|:--------|
| Weak | `bool` | Enables weak validators. Weak ETags are easier to generate but less reliable for comparisons. | `false` |
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
Weak: false,
}
```
================================================
FILE: docs/middleware/expvar.md
================================================
---
id: expvar
---
# ExpVar
The ExpVar middleware exposes runtime variables over HTTP in JSON. Once registered (e.g., `app.Use(expvarmw.New())`), it serves them at `/debug/vars`.
## Signatures
```go
func New() fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
expvarmw "github.com/gofiber/fiber/v3/middleware/expvar"
)
```
Once your Fiber app is initialized, use the middleware as shown:
```go
var count = expvar.NewInt("count")
app.Use(expvarmw.New())
app.Get("/", func(c fiber.Ctx) error {
count.Add(1)
return c.SendString(fmt.Sprintf("hello expvar count %d", count.Value()))
})
```
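Visit `/debug/vars` to see all variables. To filter the output by variable name, append `?r=pattern`; the value is matched as a pattern, so `?r=c` returns every variable whose name contains `c`.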
```bash
curl 127.0.0.1:3000
hello expvar count 1
curl 127.0.0.1:3000/debug/vars
{
"cmdline": ["xxx"],
"count": 1,
"expvarHandlerCalls": 33,
"expvarRegexpErrors": 0,
"memstats": {...}
}
curl 127.0.0.1:3000/debug/vars?r=c
{
"cmdline": ["xxx"],
"count": 1
}
```
## Config
| Property | Type | Description | Default |
|:---------|:------------------------|:--------------------------------------------------------------------|:--------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
}
```
================================================
FILE: docs/middleware/favicon.md
================================================
---
id: favicon
---
# Favicon
Favicon middleware for [Fiber](https://github.com/gofiber/fiber) that drops repeated `/favicon.ico` requests or serves a cached icon from memory. Mount it before your logger to suppress noisy requests and avoid disk reads.
It handles only `GET`, `HEAD`, and `OPTIONS` to the configured URL; other methods return `405 Method Not Allowed`.
:::note
This middleware only serves the default `/favicon.ico` (or a [custom URL](#config)). For multiple icons, use the Static middleware.
:::
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/favicon"
)
```
Once your Fiber app is initialized, use the middleware like this:
```go
// Initialize default config
app.Use(favicon.New())
// Or extend your config for customization
app.Use(favicon.New(favicon.Config{
File: "./favicon.ico",
URL: "/favicon.ico",
}))
```
## Config
| Property | Type | Description | Default |
|:-------------|:------------------------|:---------------------------------------------------------------------------------|:---------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
| Data | `[]byte` | Raw data of the favicon file. This can be used instead of `File`. | `nil` |
| File | `string` | File holds the path to an actual favicon that will be cached. | "" |
| URL | `string` | URL for favicon handler. | "/favicon.ico" |
| FileSystem | `fs.FS` | FileSystem is an optional alternate filesystem from which to load the favicon file (e.g. using `os.DirFS` or an `embed.FS`). | `nil` |
| CacheControl | `string` | CacheControl defines how the Cache-Control header in the response should be set. | "public, max-age=31536000" |
| MaxBytes | `int64` | MaxBytes limits the maximum size of the cached favicon asset. | `1048576` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
File: "",
URL: fPath,
CacheControl: "public, max-age=31536000",
MaxBytes: 1024 * 1024,
}
```
================================================
FILE: docs/middleware/healthcheck.md
================================================
---
id: healthcheck
---
# Health Check
Middleware that adds liveness, readiness, and startup probes to [Fiber](https://github.com/gofiber/fiber) apps. It provides a generic handler you can mount on any route, with constants for the conventional `/livez`, `/readyz`, and `/startupz` endpoints.
## Overview
Register the middleware on any endpoint you want to expose a probe on. The package exports constants for the conventional liveness, readiness, and startup endpoints:
```go
app.Get(healthcheck.LivenessEndpoint, healthcheck.New())
app.Get(healthcheck.ReadinessEndpoint, healthcheck.New())
app.Get(healthcheck.StartupEndpoint, healthcheck.New())
```
By default the probe returns `true`, so each endpoint responds with `200 OK`; returning `false` yields `503 Service Unavailable`.
- **Liveness**: Checks if the server is running.
- **Readiness**: Checks if the application is ready to handle requests.
- **Startup**: Checks if the application has completed its startup sequence.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/healthcheck"
)
```
After your app is initialized, register the middleware on the endpoints you want to expose:
```go
// Use the default probe on the conventional endpoints
app.Get(healthcheck.LivenessEndpoint, healthcheck.New())
app.Get(healthcheck.ReadinessEndpoint, healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return serviceA.Ready() && serviceB.Ready()
},
}))
app.Get(healthcheck.StartupEndpoint, healthcheck.New())
// Register a custom endpoint
app.Get("/healthz", healthcheck.New())
```
The middleware responds only to `GET` requests; for any other method it passes control to the next handler. Registering it with `app.All` therefore answers `GET` probes while letting other methods fall through:
```go
app.All("/healthz", healthcheck.New())
```
## Config
```go
type Config struct {
// Next defines a function to skip this middleware when it returns true. If this function returns true
// and no other handlers are defined for the route, Fiber will return a status 404 Not Found, since
// no other handlers were defined to return a different status.
//
// Optional. Default: nil
Next func(fiber.Ctx) bool
// Probe is executed to determine the current health state. It can be used for
// liveness, readiness or startup checks. Returning true indicates the application
// is healthy.
//
// Optional. Default: func(c fiber.Ctx) bool { return true }
Probe func(fiber.Ctx) bool
}
```
## Default Config
The default configuration used by this middleware is defined as follows:
```go
func defaultProbe(_ fiber.Ctx) bool { return true }
var ConfigDefault = Config{
Next: nil,
Probe: defaultProbe,
}
```
================================================
FILE: docs/middleware/helmet.md
================================================
---
id: helmet
---
# Helmet
Helmet secures your app by adding common security headers.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Once your Fiber app is initialized, add the middleware:
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/helmet"
)
func main() {
app := fiber.New()
app.Use(helmet.New())
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Welcome!")
})
app.Listen(":3000")
}
```
## Test
```bash
curl -I http://localhost:3000
```
## Config
| Property | Type | Description | Default |
|:--------------------------|:------------------------|:--------------------------------------------|:-----------------|
| Next | `func(fiber.Ctx) bool` | Skips the middleware when the function returns `true`. | `nil` |
| XSSProtection | `string` | Value for the `X-XSS-Protection` header. | "0" |
| ContentTypeNosniff | `string` | Value for the `X-Content-Type-Options` header. | "nosniff" |
| XFrameOptions | `string` | Value for the `X-Frame-Options` header. | "SAMEORIGIN" |
| HSTSMaxAge | `int` | `max-age` value for `Strict-Transport-Security`. | 0 |
| HSTSExcludeSubdomains | `bool` | Omits the `includeSubDomains` directive from `Strict-Transport-Security` when `true`. | false |
| ContentSecurityPolicy | `string` | Value for the `Content-Security-Policy` header. | "" |
| CSPReportOnly | `bool` | Enables report-only mode for CSP. | false |
| HSTSPreloadEnabled | `bool` | Adds the `preload` directive to HSTS. | false |
| ReferrerPolicy | `string` | Value for the `Referrer-Policy` header. | "no-referrer" |
| PermissionPolicy | `string` | Value for the `Permissions-Policy` header. | "" |
| CrossOriginEmbedderPolicy | `string` | Value for the `Cross-Origin-Embedder-Policy` header. | "require-corp" |
| CrossOriginOpenerPolicy | `string` | Value for the `Cross-Origin-Opener-Policy` header. | "same-origin" |
| CrossOriginResourcePolicy | `string` | Value for the `Cross-Origin-Resource-Policy` header. | "same-origin" |
| OriginAgentCluster | `string` | Value for the `Origin-Agent-Cluster` header. | "?1" |
| XDNSPrefetchControl | `string` | Value for the `X-DNS-Prefetch-Control` header. | "off" |
| XDownloadOptions | `string` | Value for the `X-Download-Options` header. | "noopen" |
| XPermittedCrossDomain | `string` | Value for the `X-Permitted-Cross-Domain-Policies` header. | "none" |
## Default Config
```go
var ConfigDefault = Config{
XSSProtection: "0",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
ReferrerPolicy: "no-referrer",
CrossOriginEmbedderPolicy: "require-corp",
CrossOriginOpenerPolicy: "same-origin",
CrossOriginResourcePolicy: "same-origin",
OriginAgentCluster: "?1",
XDNSPrefetchControl: "off",
XDownloadOptions: "noopen",
XPermittedCrossDomain: "none",
}
```
================================================
FILE: docs/middleware/idempotency.md
================================================
---
id: idempotency
---
# Idempotency
The Idempotency middleware helps build fault-tolerant APIs. Duplicate requests—such as retries after network issues—won't trigger the same action twice on the server.
Refer to [IETF RFC 7231 §4.2.2](https://tools.ietf.org/html/rfc7231#section-4.2.2) for definitions of safe and idempotent HTTP methods.
## HTTP Method Categories
* **Safe Methods** (do not modify server state): `GET`, `HEAD`, `OPTIONS`, `TRACE`
* **Idempotent Methods** (identical requests have the same effect as a single one): all safe methods **plus** `PUT` and `DELETE`
> According to the RFC, safe methods never change server state, while idempotent methods may change state but remain safe to repeat.
## Signatures
```go
func New(config ...Config) fiber.Handler
func IsFromCache(c fiber.Ctx) bool
func WasPutToCache(c fiber.Ctx) bool
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/idempotency"
)
```
Once your Fiber app is initialized, configure the middleware:
### Default Config (Skip **Safe** Methods)
By default, the `Next` function skips middleware for safe methods only:
```go
app.Use(idempotency.New())
```
### Skip **Idempotent** Methods Instead
Skip all idempotent methods (including `PUT` and `DELETE`) by overriding `Next`:
```go
app.Use(idempotency.New(idempotency.Config{
Next: func(c fiber.Ctx) bool {
// Skip middleware for idempotent methods (safe + PUT, DELETE)
return fiber.IsMethodIdempotent(c.Method())
},
}))
```
### Custom Config
```go
app.Use(idempotency.New(idempotency.Config{
Lifetime: 42 * time.Minute,
// ...
}))
```
## Config
Idempotency keys are hidden in logs and error messages by default. Set `DisableValueRedaction` to `true` only when you need to expose them for debugging.
| Property | Type | Description | Default |
|:--------------------|:-----------------------|:----------------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------|
| Next | `func(fiber.Ctx) bool` | Function to skip this middleware when it returns `true`; use `IsMethodSafe` or `IsMethodIdempotent`. | `func(c fiber.Ctx) bool { return fiber.IsMethodSafe(c.Method()) }` |
| Lifetime | `time.Duration` | Maximum lifetime of an idempotency key. | `30 * time.Minute` |
| KeyHeader | `string` | Header name containing the idempotency key. | `"X-Idempotency-Key"` |
| KeyHeaderValidate | `func(string) error` | Function to validate idempotency header syntax (e.g., UUID). | UUID length check (`36` characters) |
| KeepResponseHeaders | `[]string` | List of headers to preserve from original response. | `nil` (keep all headers) |
| DisableValueRedaction | `bool` | Disables idempotency key redaction in logs and error messages. | `false` |
| Lock | `Locker` | Locks an idempotency key to prevent race conditions. | In-memory locker |
| Storage | `fiber.Storage` | Stores response data by idempotency key. | In-memory storage |
## Default Config Values
```go
var ConfigDefault = Config{
Next: func(c fiber.Ctx) bool {
// Skip middleware for safe methods per RFC 7231 §4.2.2
return fiber.IsMethodSafe(c.Method())
},
Lifetime: 30 * time.Minute,
KeyHeader: "X-Idempotency-Key",
KeyHeaderValidate: func(k string) error {
if l, wl := len(k), 36; l != wl { // UUID length is 36 chars
return fmt.Errorf("%w: invalid length: %d != %d", ErrInvalidIdempotencyKey, l, wl)
}
return nil
},
KeepResponseHeaders: nil,
Lock: nil, // Set in configDefault so we don't allocate data here.
Storage: nil, // Set in configDefault so we don't allocate data here.
DisableValueRedaction: false,
}
```
================================================
FILE: docs/middleware/keyauth.md
================================================
---
id: keyauth
---
# KeyAuth
The KeyAuth middleware implements API key authentication.
## Signatures
```go
func New(config ...Config) fiber.Handler
func TokenFromContext(ctx any) string
```
`TokenFromContext` accepts a `fiber.CustomCtx`, `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`.
## Examples
### Basic example
This example registers KeyAuth with an API key stored in a cookie.
```go
package main
import (
"crypto/sha256"
"crypto/subtle"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/keyauth"
)
var (
apiKey = "correct horse battery staple"
)
func validateAPIKey(c fiber.Ctx, key string) (bool, error) {
hashedAPIKey := sha256.Sum256([]byte(apiKey))
hashedKey := sha256.Sum256([]byte(key))
if subtle.ConstantTimeCompare(hashedAPIKey[:], hashedKey[:]) == 1 {
return true, nil
}
return false, keyauth.ErrMissingOrMalformedAPIKey
}
func main() {
app := fiber.New()
// Register middleware before the routes that need it
app.Use(keyauth.New(keyauth.Config{
Extractor: keyauth.FromCookie("access_token"),
Validator: validateAPIKey,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Successfully authenticated!")
})
app.Listen(":3000")
}
```
**Test:**
```bash
# No API key specified -> 401 Missing or invalid API Key
curl http://localhost:3000
#> Missing or invalid API Key
# Correct API key -> 200 OK
curl --cookie "access_token=correct horse battery staple" http://localhost:3000
#> Successfully authenticated!
# Incorrect API key -> 401 Missing or invalid API Key
curl --cookie "access_token=Clearly A Wrong Key" http://localhost:3000
#> Missing or invalid API Key
```
For a more detailed example, see the [`fiber-envoy-extauthz`](https://github.com/gofiber/recipes/tree/master/fiber-envoy-extauthz) recipe in the `gofiber/recipes` repository.
### Authenticate only certain endpoints
Use the `Next` function to run KeyAuth only on selected routes.
```go
package main
import (
"crypto/sha256"
"crypto/subtle"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/keyauth"
"regexp"
"strings"
)
var (
apiKey = "correct horse battery staple"
protectedURLs = []*regexp.Regexp{
regexp.MustCompile("^/authenticated$"),
regexp.MustCompile("^/auth2$"),
}
)
func validateAPIKey(c fiber.Ctx, key string) (bool, error) {
hashedAPIKey := sha256.Sum256([]byte(apiKey))
hashedKey := sha256.Sum256([]byte(key))
if subtle.ConstantTimeCompare(hashedAPIKey[:], hashedKey[:]) == 1 {
return true, nil
}
return false, keyauth.ErrMissingOrMalformedAPIKey
}
func authFilter(c fiber.Ctx) bool {
originalURL := strings.ToLower(c.OriginalURL())
for _, pattern := range protectedURLs {
if pattern.MatchString(originalURL) {
// Run middleware for protected routes
return false
}
}
// Skip middleware for non-protected routes
return true
}
func main() {
app := fiber.New()
app.Use(keyauth.New(keyauth.Config{
Next: authFilter,
Extractor: keyauth.FromCookie("access_token"),
Validator: validateAPIKey,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Welcome")
})
app.Get("/authenticated", func(c fiber.Ctx) error {
return c.SendString("Successfully authenticated!")
})
app.Get("/auth2", func(c fiber.Ctx) error {
return c.SendString("Successfully authenticated 2!")
})
app.Listen(":3000")
}
```
**Test:**
```bash
# / doesn't require authentication
curl http://localhost:3000
#> Welcome
# /authenticated requires authentication
curl --cookie "access_token=correct horse battery staple" http://localhost:3000/authenticated
#> Successfully authenticated!
# /auth2 requires authentication too
curl --cookie "access_token=correct horse battery staple" http://localhost:3000/auth2
#> Successfully authenticated 2!
```
### Apply middleware in the handler
You can apply the middleware to specific routes or groups instead of globally. This example uses the default extractor (`FromAuthHeader`).
```go
package main
import (
"crypto/sha256"
"crypto/subtle"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/keyauth"
)
const (
apiKey = "my-super-secret-key"
)
func main() {
app := fiber.New()
authMiddleware := keyauth.New(keyauth.Config{
Validator: func(c fiber.Ctx, key string) (bool, error) {
hashedAPIKey := sha256.Sum256([]byte(apiKey))
hashedKey := sha256.Sum256([]byte(key))
if subtle.ConstantTimeCompare(hashedAPIKey[:], hashedKey[:]) == 1 {
return true, nil
}
return false, keyauth.ErrMissingOrMalformedAPIKey
},
})
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Welcome")
})
app.Get("/allowed", authMiddleware, func(c fiber.Ctx) error {
return c.SendString("Successfully authenticated!")
})
app.Listen(":3000")
}
```
**Test:**
```bash
# / doesn't require authentication
curl http://localhost:3000
#> Welcome
# /allowed requires authentication
curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000/allowed
#> Successfully authenticated!
```
## Key Extractors
KeyAuth uses an `Extractor` from the shared [extractors](../guide/extractors) package to retrieve the API key from the request. Configure a single extractor, or combine several with `extractors.Chain` to try multiple sources in order. For a full list of extractors, chaining, and advanced usage, see the [Extractors Guide](../guide/extractors).
### Typical Usage
Specify the extractor in the config. For example, to extract from a cookie:
```go
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.FromCookie("access_token"),
Validator: validateAPIKey,
}))
```
To use the default (Authorization header with Bearer scheme):
```go
app.Use(keyauth.New(keyauth.Config{
Validator: validateAPIKey, // Extractor defaults to FromAuthHeader("Bearer")
}))
```
To try multiple sources (header, then query):
```go
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.Chain(
extractors.FromHeader("X-API-Key"),
extractors.FromQuery("api_key"),
),
Validator: validateAPIKey,
}))
```
For custom logic, use `extractors.FromCustom`:
```go
app.Use(keyauth.New(keyauth.Config{
Extractor: extractors.FromCustom(func(c fiber.Ctx) (string, error) {
return c.Get("X-My-API-Key"), nil
}),
Validator: validateAPIKey,
}))
```
Refer to the [Extractors Guide](../guide/extractors) for details, security notes, and advanced configuration.
## Config
| Property | Type | Description | Default |
|:----------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
| SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `c.Next()` |
| ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. By default a 401 response with a `WWW-Authenticate` challenge is sent. | Default error handler |
| Validator | `func(fiber.Ctx, string) (bool, error)` | **Required.** Validator is a function to validate the key. | `nil` (panic) |
| Extractor | `extractors.Extractor` | Extractor defines how to retrieve the key from the request. Use helper functions from the shared extractors package, e.g. `extractors.FromAuthHeader("Bearer")` or `extractors.FromCookie("access_token")`. | `extractors.FromAuthHeader("Bearer")` |
| Realm | `string` | Realm specifies the protected area name used in the `WWW-Authenticate` header. | `"Restricted"` |
| Challenge | `string` | Value of the `WWW-Authenticate` header when no `Authorization` scheme is present. | `ApiKey realm="Restricted"` |
| Error | `string` | Error code appended as the `error` parameter in Bearer challenges. Must be `invalid_request`, `invalid_token`, or `insufficient_scope`. | `""` |
| ErrorDescription| `string` | Human-readable text for the `error_description` parameter in Bearer challenges. Requires `Error`. | `""` |
| ErrorURI | `string` | URI identifying a human-readable web page with information about the `error` in Bearer challenges. Requires `Error` and must be an absolute URI. | `""` |
| Scope | `string` | Space-delimited list of scopes for the `scope` parameter in Bearer challenges. Each token must conform to the RFC 6750 `scope-token` syntax and requires `Error` set to `insufficient_scope`. | `""` |
## Default Config
```go
var ConfigDefault = Config{
SuccessHandler: func(c fiber.Ctx) error {
return c.Next()
},
ErrorHandler: func(c fiber.Ctx, _ error) error {
return c.Status(fiber.StatusUnauthorized).SendString(ErrMissingOrMalformedAPIKey.Error())
},
Realm: "Restricted",
Extractor: extractors.FromAuthHeader("Bearer"),
}
```
================================================
FILE: docs/middleware/limiter.md
================================================
---
id: limiter
---
# Limiter
The Limiter middleware for [Fiber](https://github.com/gofiber/fiber) throttles repeated requests to public APIs or endpoints such as password resets. It's also useful for API clients, web crawlers, or other tasks that need rate limiting.
Limiter redacts request keys in error paths by default so storage identifiers and rate-limit keys don't leak into logs. Set `DisableValueRedaction` to `true` when you explicitly need the raw key for troubleshooting.
:::note
This middleware uses our [Storage](https://github.com/gofiber/storage) package to support various databases through a single interface. The default configuration for this middleware saves data to memory; see the examples below for other databases.
:::
:::note
This module does not share state with other processes/servers by default.
:::
## Signatures
```go
func New(config ...Config) fiber.Handler
type Handler interface {
New(config *Config) fiber.Handler
}
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/limiter"
)
```
Once your Fiber app is initialized, use the middleware like this:
```go
// Initialize default config
app.Use(limiter.New())
// Or extend your config for customization
app.Use(limiter.New(limiter.Config{
Next: func(c fiber.Ctx) bool {
return c.IP() == "127.0.0.1"
},
Max: 20,
MaxFunc: func(c fiber.Ctx) int {
return 20
},
Expiration: 30 * time.Second,
ExpirationFunc: func(c fiber.Ctx) time.Duration {
// Use longer expiration for sensitive endpoints
if c.Path() == "/login" {
return 60 * time.Second
}
return 30 * time.Second
},
KeyGenerator: func(c fiber.Ctx) string {
return c.Get("x-forwarded-for")
},
LimitReached: func(c fiber.Ctx) error {
return c.SendFile("./toofast.html")
},
Storage: myCustomStorage{},
}))
```
## Sliding window
Instead of using the standard fixed window algorithm, you can enable the [sliding window](https://en.wikipedia.org/wiki/Sliding_window_protocol) algorithm.
An example configuration is:
```go
app.Use(limiter.New(limiter.Config{
Max: 20,
Expiration: 30 * time.Second,
LimiterMiddleware: limiter.SlidingWindow{},
}))
```
Each new window also considers the previous one (if any). The rate is calculated as:
```text
weightOfPreviousWindow = previousWindowRequests * ((Expiration - elapsedInCurrentWindow) / Expiration)
rate = weightOfPreviousWindow + currentWindowRequests
```
## Dynamic limit
You can also calculate the limit dynamically using the `MaxFunc` parameter. It receives the request context and allows you to compute a different limit for each request.
Example:
```go
app.Use(limiter.New(limiter.Config{
MaxFunc: func(c fiber.Ctx) int {
return getUserLimit(c.Params("id"))
},
Expiration: 30 * time.Second,
}))
```
## Dynamic expiration
You can also calculate the expiration dynamically using the `ExpirationFunc` parameter. It receives the request context and allows you to set a different expiration window for each request.
Example:
```go
app.Use(limiter.New(limiter.Config{
Max: 20,
ExpirationFunc: func(c fiber.Ctx) time.Duration {
return getExpirationForRoute(c.Path())
},
}))
```
## Config
| Property | Type | Description | Default |
|:-----------------------|:--------------------------|:--------------------------------------------------------------------------------------------|:-----------------------------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
| Max | `int` | Maximum number of requests allowed within the `Expiration` window before sending a 429 response. | 5 |
| MaxFunc | `func(fiber.Ctx) int` | Function that calculates the maximum number of requests allowed within the `Expiration` window before sending a 429 response. | A function that returns `cfg.Max` |
| KeyGenerator | `func(fiber.Ctx) string` | Function to generate custom keys; uses `c.IP()` by default. | A function using `c.IP()` as the default |
| Expiration | `time.Duration` | Duration to keep request records in memory. | 1 * time.Minute |
| ExpirationFunc | `func(fiber.Ctx) time.Duration` | Function that calculates the expiration duration dynamically. | A function that returns `cfg.Expiration` |
| LimitReached | `fiber.Handler` | Called when a request exceeds the limit. | A function sending a 429 response |
| SkipFailedRequests | `bool` | When set to `true`, requests with status code ≥ 400 aren't counted. | false |
| SkipSuccessfulRequests | `bool` | When set to `true`, requests with status code < 400 aren't counted. | false |
| DisableHeaders | `bool` | When set to `true`, the middleware omits rate limit headers (`X-RateLimit-*` and `Retry-After`). | false |
| DisableValueRedaction | `bool` | Disables redaction of limiter keys in error messages and logs. | false |
| Storage | `fiber.Storage` | Persists middleware state. | An in-memory store for this process only |
| LimiterMiddleware | `limiter.Handler` | Selects the algorithm implementation. Implementations now receive a pointer to the active config when their `New` method is invoked. | A new Fixed Window Rate Limiter |
:::note
A custom store can be used if it implements the `Storage` interface - more details and an example can be found in `store.go`.
:::
## Default Config
```go
var ConfigDefault = Config{
Max: 5,
MaxFunc: func(c fiber.Ctx) int {
return 5
},
Expiration: 1 * time.Minute,
// ExpirationFunc defaults to nil and is set dynamically to return cfg.Expiration
KeyGenerator: func(c fiber.Ctx) string {
return c.IP()
},
LimitReached: func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusTooManyRequests)
},
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
DisableHeaders: false,
DisableValueRedaction: false,
LimiterMiddleware: FixedWindow{},
}
```
### Custom Storage/Database
You can use any storage from our [storage](https://github.com/gofiber/storage/) package.
```go
storage := sqlite3.New() // From github.com/gofiber/storage/sqlite3/v2
app.Use(limiter.New(limiter.Config{
Storage: storage,
}))
```
================================================
FILE: docs/middleware/logger.md
================================================
---
id: logger
---
# Logger
Logger middleware for [Fiber](https://github.com/gofiber/fiber) that logs HTTP requests and responses.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/logger"
)
```
:::tip
Registration order matters: only routes added after the logger are logged, so register it early.
:::
Once your Fiber app is initialized, use the middleware like this:
```go
// Initialize default config
app.Use(logger.New())
// Or extend your config for customization
// Log remote IP and port
app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
}))
// Logging Request ID
app.Use(requestid.New()) // Ensure requestid middleware is used before the logger
app.Use(logger.New(logger.Config{
CustomTags: map[string]logger.LogFunc{
"requestid": func(output logger.Buffer, c fiber.Ctx, data *logger.Data, extraParam string) (int, error) {
return output.WriteString(requestid.FromContext(c))
},
},
// For more options, see the Config section
// Use the custom tag ${requestid} as defined above.
Format: "${pid} ${requestid} ${status} - ${method} ${path}\n",
}))
// Changing TimeZone & TimeFormat
app.Use(logger.New(logger.Config{
Format: "${pid} ${status} - ${method} ${path}\n",
TimeFormat: "02-Jan-2006",
TimeZone: "America/New_York",
}))
// Custom File Writer
accessLog, err := os.OpenFile("./access.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
log.Fatalf("error opening access.log file: %v", err)
}
defer accessLog.Close()
app.Use(logger.New(logger.Config{
Stream: accessLog,
}))
// Add Custom Tags
app.Use(logger.New(logger.Config{
CustomTags: map[string]logger.LogFunc{
"custom_tag": func(output logger.Buffer, c fiber.Ctx, data *logger.Data, extraParam string) (int, error) {
return output.WriteString("it is a custom tag")
},
},
}))
// Callback after log is written
app.Use(logger.New(logger.Config{
TimeFormat: time.RFC3339Nano,
TimeZone: "Asia/Shanghai",
Done: func(c fiber.Ctx, logString []byte) {
if c.Response().StatusCode() != fiber.StatusOK {
reporter.SendToSlack(logString)
}
},
}))
// Disable colors when outputting to default format
app.Use(logger.New(logger.Config{
DisableColors: true,
}))
// Force the use of colors
app.Use(logger.New(logger.Config{
ForceColors: true,
}))
// Use predefined formats
app.Use(logger.New(logger.Config{
Format: logger.CommonFormat,
}))
app.Use(logger.New(logger.Config{
Format: logger.CombinedFormat,
}))
app.Use(logger.New(logger.Config{
Format: logger.JSONFormat,
}))
app.Use(logger.New(logger.Config{
Format: logger.ECSFormat,
}))
```
### Use Logger Middleware with Other Loggers
To combine the logger middleware with loggers like Zerolog, Zap, or Logrus, use the `LoggerToWriter` helper to adapt them to an `io.Writer`.
```go
package main
import (
"github.com/gofiber/contrib/fiberzap/v2"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/fiber/v3/middleware/logger"
)
func main() {
// Create a new Fiber instance
app := fiber.New()
// Create a new zap logger which is compatible with Fiber AllLogger interface
zap := fiberzap.NewLogger(fiberzap.LoggerConfig{
ExtraKeys: []string{"request_id"},
})
// Use the logger middleware with the zap logger
app.Use(logger.New(logger.Config{
Stream: logger.LoggerToWriter(zap, log.LevelDebug),
}))
// Define a route
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
// Start server on http://localhost:3000
app.Listen(":3000")
}
```
:::tip
Writing to `os.File` is goroutine-safe, but custom streams may require locking to serialize writes.
:::
## Config
| Property | Type | Description | Default |
| :------------ | :------------------------------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------- |
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
| Skip | `func(fiber.Ctx) bool` | Skip is a function to determine if logging is skipped or written to Stream. | `nil` |
| Done | `func(fiber.Ctx, []byte)` | Done is a function that is called after the log string for a request is written to Stream; it receives the log string as a parameter. | `nil` |
| CustomTags | `map[string]LogFunc` | CustomTags defines custom tag actions, keyed by the tag name used in `Format` (e.g. `${requestid}`). | `map[string]LogFunc` |
| `Format` | `string` | Defines the logging tags. See more in [Predefined Formats](#predefined-formats), or create your own using [Tags](#constants). | `[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n` (same as `DefaultFormat`) |
| TimeFormat | `string` | TimeFormat defines the time format for log timestamps. | `15:04:05` |
| TimeZone | `string` | TimeZone sets the time zone used for log timestamps, such as "UTC", "America/New_York", or "Asia/Chongqing". | `"Local"` |
| TimeInterval | `time.Duration` | TimeInterval is the delay before the timestamp is updated. | `500 * time.Millisecond` |
| Stream | `io.Writer` | Stream is a writer where logs are written. | `os.Stdout` |
| LoggerFunc | `func(c fiber.Ctx, data *Data, cfg *Config) error` | Custom logger function for integration with logging libraries (Zerolog, Zap, Logrus, etc). Defaults to Fiber's default logger if not defined. | `see default_logger.go defaultLoggerInstance` |
| DisableColors | `bool` | DisableColors disables colorized log output when set to `true`. | `false` |
| ForceColors | `bool` | ForceColors defines if the logs output should be colorized even when the output is not a terminal. | `false` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
Skip: nil,
Done: nil,
Format: DefaultFormat,
TimeFormat: "15:04:05",
TimeZone: "Local",
TimeInterval: 500 * time.Millisecond,
Stream: os.Stdout,
BeforeHandlerFunc: beforeHandlerFunc,
LoggerFunc: defaultLoggerInstance,
enableColors: true,
}
```
## Predefined Formats
Logger provides predefined formats that you can use by name or directly by specifying the format string.
| **Format Constant** | **Format String** | **Description** |
|---------------------|--------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
| `DefaultFormat` | `"[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n"` | Fiber's default logger format. |
| `CommonFormat` | `"${ip} - - [${time}] "${method} ${url} ${protocol}" ${status} ${bytesSent}\n"` | Common Log Format (CLF) used in web server logs. |
| `CombinedFormat` | `"${ip} - - [${time}] "${method} ${url} ${protocol}" ${status} ${bytesSent} "${referer}" "${ua}"\n"` | CLF format plus the `referer` and `user agent` fields. |
| `JSONFormat` | `"{time: ${time}, ip: ${ip}, method: ${method}, url: ${url}, status: ${status}, bytesSent: ${bytesSent}}\n"` | JSON format for structured logging. |
| `ECSFormat` | `"{\"@timestamp\":\"${time}\",\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":\"${ip}\"},\"http\":{\"request\":{\"method\":\"${method}\",\"url\":\"${url}\",\"protocol\":\"${protocol}\"},\"response\":{\"status_code\":${status},\"body\":{\"bytes\":${bytesSent}}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":\"${method} ${url} responded with ${status}\"}\n"` | Elastic Common Schema (ECS) format for structured logging. |
## Constants
```go
// Logger variables
const (
TagPid = "pid"
TagTime = "time"
TagReferer = "referer"
TagProtocol = "protocol"
TagPort = "port"
TagIP = "ip"
TagIPs = "ips"
TagHost = "host"
TagMethod = "method"
TagPath = "path"
TagURL = "url"
TagUA = "ua"
TagLatency = "latency"
TagStatus = "status" // response status
TagResBody = "resBody" // response body
TagReqHeaders = "reqHeaders"
TagQueryStringParams = "queryParams" // request query parameters
TagBody = "body" // request body
TagBytesSent = "bytesSent"
TagBytesReceived = "bytesReceived"
TagRoute = "route"
TagError = "error"
TagReqHeader = "reqHeader:" // request header
TagRespHeader = "respHeader:" // response header
TagQuery = "query:" // request query
TagForm = "form:" // request form
TagCookie = "cookie:" // request cookie
TagLocals = "locals:"
// colors
TagBlack = "black"
TagRed = "red"
TagGreen = "green"
TagYellow = "yellow"
TagBlue = "blue"
TagMagenta = "magenta"
TagCyan = "cyan"
TagWhite = "white"
TagReset = "reset"
)
```
================================================
FILE: docs/middleware/paginate.md
================================================
---
id: paginate
---
# Paginate
Pagination middleware for [Fiber](https://github.com/gofiber/fiber) that extracts pagination parameters from query strings and stores them in the request context. Supports page-based, offset-based, and cursor-based pagination with multi-field sorting.
## Signatures
```go
func New(config ...Config) fiber.Handler
func FromContext(ctx any) (*PageInfo, bool)
```
`FromContext` accepts `fiber.CustomCtx`, `fiber.Ctx`, `*fasthttp.RequestCtx`, or `context.Context`.
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/paginate"
)
```
Once your Fiber app is initialized, choose one of the following approaches:
### Basic Usage
```go
app.Use(paginate.New())
app.Get("/users", func(c fiber.Ctx) error {
pageInfo, ok := paginate.FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
// Use pageInfo.Page, pageInfo.Limit, pageInfo.Start()
// GET /users?page=2&limit=20 → Page: 2, Limit: 20, Start(): 20
return c.JSON(pageInfo)
})
```
### Sorting
```go
app.Use(paginate.New(paginate.Config{
SortKey: "sort",
DefaultSort: "id",
AllowedSorts: []string{"id", "name", "created_at"},
}))
// GET /users?sort=name,-created_at
// → Sort: [{Field: "name", Order: "asc"}, {Field: "created_at", Order: "desc"}]
```
### Cursor Pagination
```go
app.Use(paginate.New())
app.Get("/feed", func(c fiber.Ctx) error {
pageInfo, ok := paginate.FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
if pageInfo.Cursor != "" {
// Decode the cursor to get keyset values
values := pageInfo.CursorValues()
// Use values["id"], values["created_at"], etc. for WHERE clause
}
// results is a slice of items from your database query
// After fetching results, set the next cursor for the client
if len(results) > 0 {
lastItem := results[len(results)-1]
if err := pageInfo.SetNextCursor(map[string]any{
"id": lastItem.ID,
"created_at": lastItem.CreatedAt,
}); err != nil {
return err
}
}
return c.JSON(fiber.Map{
"data": results,
"has_more": pageInfo.HasMore,
"next_cursor": pageInfo.NextCursor,
})
})
// First request: GET /feed?limit=20
// Next request: GET /feed?cursor=<next_cursor>&limit=20
```
### Custom Configuration
```go
app.Use(paginate.New(paginate.Config{
PageKey: "p",
LimitKey: "size",
DefaultPage: 1,
DefaultLimit: 25,
SortKey: "order_by",
DefaultSort: "created_at",
AllowedSorts: []string{"created_at", "name", "email"},
CursorKey: "after",
CursorParam: "starting_after",
}))
```
## Config
| Property | Type | Description | Default |
|:-------------|:-------------------------|:-------------------------------------------------------------------|:-----------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
| PageKey | `string` | Query string key for page number. | `"page"` |
| DefaultPage | `int` | Default page number. | `1` |
| LimitKey | `string` | Query string key for limit. | `"limit"` |
| DefaultLimit | `int` | Default items per page. | `10` |
| SortKey | `string` | Query string key for sort. | `""` |
| DefaultSort | `string` | Default sort field. | `"id"` |
| AllowedSorts | `[]string` | Whitelist of allowed sort fields. If nil, all fields are allowed. | `nil` |
| OffsetKey | `string` | Query string key for offset. | `"offset"` |
| CursorKey | `string` | Query string key for cursor-based pagination. | `"cursor"` |
| CursorParam | `string` | Optional alias for the cursor query key. | `""` |
| MaxLimit | `int` | Maximum items per page. | `100` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
PageKey: "page",
DefaultPage: 1,
LimitKey: "limit",
DefaultLimit: 10,
MaxLimit: 100,
DefaultSort: "id",
OffsetKey: "offset",
CursorKey: "cursor",
}
```
## PageInfo
The `PageInfo` struct is stored in the request context and provides:
| Method | Description |
|:------------------------------------------------|:---------------------------------------------------------------|
| `Start() int` | Returns calculated start index (from page/limit or offset) |
| `SortBy(field, order)` | Adds a sort field (chainable) |
| `NextPageURL(baseURL)` | Generates next page URL with default keys |
| `NextPageURLWithKeys(baseURL, pageKey, limitKey)` | Generates next page URL with custom query keys |
| `PreviousPageURL(baseURL)` | Generates previous page URL (empty on page 1) |
| `PreviousPageURLWithKeys(baseURL, pageKey, limitKey)` | Generates previous page URL with custom query keys |
| `NextCursorURL(baseURL)` | Generates next cursor URL (empty if no more) |
| `NextCursorURLWithKeys(baseURL, cursorKey, limitKey)` | Generates next cursor URL with custom query keys |
| `CursorValues()` | Decodes cursor token into key-value map |
| `SetNextCursor(values)` | Encodes values into cursor token, sets HasMore; returns error |
## Safety
- Limit is capped at `MaxLimit` (default: 100, configurable) to prevent excessive memory usage
- Page values below 1 reset to 1
- Negative offsets reset to 0
- Sort fields are validated against `AllowedSorts`
- Cursor tokens exceeding 2048 characters are rejected with `400 Bad Request`
- `SetNextCursor` returns an error if the encoded token would exceed 2048 characters, preventing the server from issuing cursors it would later reject
- Invalid cursor tokens return `400 Bad Request` via Fiber's error handler
- If `DefaultSort` is not included in `AllowedSorts`, it falls back to the first allowed sort field
- URL helpers preserve existing query parameters when building pagination links
================================================
FILE: docs/middleware/pprof.md
================================================
---
id: pprof
---
# Pprof
Pprof middleware exposes runtime profiling data for analysis with the Go `pprof` tool. Once registered (for example with `app.Use(pprof.New())`), it serves handlers under `/debug/pprof/`.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/pprof"
)
```
Once your Fiber app is initialized, use the middleware as shown:
```go
// Initialize default config
app.Use(pprof.New())
// Or customize the config
// For multi-ingress systems, add a URL prefix:
app.Use(pprof.New(pprof.Config{Prefix: "/endpoint-prefix"}))
// The resulting URL is "/endpoint-prefix/debug/pprof/"
```
## Config
| Property | Type | Description | Default |
|:---------|:------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------:|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
| Prefix | `string` | Prefix adds a segment before `/debug/pprof`; it must start with a slash and omit the trailing slash. Example: `/federated-fiber` | `""` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
}
```
================================================
FILE: docs/middleware/proxy.md
================================================
---
id: proxy
---
# Proxy
The Proxy middleware forwards requests to one or more upstream servers.
## Signatures
```go
// Balancer creates a load balancer among multiple upstream servers.
func Balancer(config ...Config) fiber.Handler
// Forward performs the given http request and fills the given http response.
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler
// Do performs the given http request and fills the given http response.
func Do(c fiber.Ctx, addr string, clients ...*fasthttp.Client) error
// DoRedirects performs the given http request and fills the given http response while following up to maxRedirectsCount redirects.
func DoRedirects(c fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error
// DoDeadline performs the given request and waits for response until the given deadline.
func DoDeadline(c fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error
// DoTimeout performs the given request and waits for response during the given timeout duration.
func DoTimeout(c fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error
// DomainForward performs the given http request based on the provided domain and fills the given http response.
func DomainForward(hostname string, addr string, clients ...*fasthttp.Client) fiber.Handler
// BalancerForward performs the given http request based on a round-robin balancer and fills the given http response.
func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/proxy"
)
```
Once your Fiber app is initialized, you can use the middleware as shown:
```go
// Use proxy.WithClient to set a global custom client.
proxy.WithClient(&fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
// Allow self-signed certificates when proxying to HTTPS targets.
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
})
// Forward requests for a specific domain with proxy.DomainForward.
app.Get("/payments", proxy.DomainForward("docs.gofiber.io", "http://localhost:8000"))
// Forward to a URL using a custom client
app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif", &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
}))
// Make a proxied request within a handler
app.Get("/:id", func(c fiber.Ctx) error {
url := "https://i.imgur.com/" + c.Params("id") + ".gif"
if err := proxy.Do(c, url); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})
// Proxy requests while following redirects
app.Get("/proxy", func(c fiber.Ctx) error {
if err := proxy.DoRedirects(c, "http://google.com", 3); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})
// Proxy requests and wait up to five seconds before timing out
app.Get("/proxy", func(c fiber.Ctx) error {
if err := proxy.DoTimeout(c, "http://localhost:3000", time.Second * 5); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})
// Proxy requests with a deadline one minute from now
app.Get("/proxy", func(c fiber.Ctx) error {
if err := proxy.DoDeadline(c, "http://localhost", time.Now().Add(time.Minute)); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})
// Minimal round-robin balancer
app.Use(proxy.Balancer(proxy.Config{
Servers: []string{
"http://localhost:3001",
"http://localhost:3002",
"http://localhost:3003",
},
}))
// Keep the Connection header when proxying
app.Use(proxy.Balancer(proxy.Config{
Servers: []string{
"http://localhost:3001",
},
KeepConnectionHeader: true,
}))
// Or extend your balancer for customization
app.Use(proxy.Balancer(proxy.Config{
Servers: []string{
"http://localhost:3001",
"http://localhost:3002",
"http://localhost:3003",
},
ModifyRequest: func(c fiber.Ctx) error {
c.Request().Header.Add("X-Real-IP", c.IP())
return nil
},
ModifyResponse: func(c fiber.Ctx) error {
c.Response().Header.Del(fiber.HeaderServer)
return nil
},
}))
// Or this way if the balancer is using https and the destination server is only using http.
app.Use(proxy.BalancerForward([]string{
"http://localhost:3001",
"http://localhost:3002",
"http://localhost:3003",
}))
// Make round robin balancer with IPv6 support.
app.Use(proxy.Balancer(proxy.Config{
Servers: []string{
"http://[::1]:3001",
"http://127.0.0.1:3002",
"http://localhost:3003",
},
// Enable TCP4 and TCP6 network stacks.
DialDualStack: true,
}))
```
## Config
| Property | Type | Description | Default |
|:----------------|:-----------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
| Servers | `[]string` | Servers defines a list of `<scheme>://<host>` HTTP servers, which are used in a round-robin manner, e.g.: "[https://foobar.com](https://foobar.com), [http://www.foobar.com](http://www.foobar.com)" | (Required) |
| ModifyRequest | `fiber.Handler` | ModifyRequest allows you to alter the request. | `nil` |
| ModifyResponse | `fiber.Handler` | ModifyResponse allows you to alter the response. | `nil` |
| Timeout | `time.Duration` | Timeout is the request timeout used when calling the proxy client. | 1 second |
| ReadBufferSize | `int` | Per-connection buffer size for requests' reading. This also limits the maximum header size. Increase this buffer if your clients send multi-KB RequestURIs and/or multi-KB headers (for example, BIG cookies). | (Not specified) |
| WriteBufferSize | `int` | Per-connection buffer size for responses' writing. | (Not specified) |
| KeepConnectionHeader | `bool` | Keeps the `Connection` header when set to `true`. By default the header is removed to comply with RFC 7230 §6.1 and avoid proxy loops. | `false` |
| TLSConfig | `*tls.Config` | TLS config for the HTTP client. | `nil` |
| DialDualStack | `bool` | Client will attempt to connect to both IPv4 and IPv6 host addresses if set to true. | `false` |
| Client | `*fasthttp.LBClient` | Client is a custom client when client config is complex. | `nil` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
ModifyRequest: nil,
ModifyResponse: nil,
Timeout: fasthttp.DefaultLBClientTimeout,
KeepConnectionHeader: false,
}
```
================================================
FILE: docs/middleware/recover.md
================================================
---
id: recover
---
# Recover
The Recover middleware for [Fiber](https://github.com/gofiber/fiber) intercepts panics and forwards them to the central [ErrorHandler](../guide/error-handling).
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
recoverer "github.com/gofiber/fiber/v3/middleware/recover"
)
```
Once your Fiber app is initialized, use the middleware like this:
```go
// Initialize default config
app.Use(recoverer.New())
// Panics in subsequent handlers are caught by the middleware
app.Get("/", func(c fiber.Ctx) error {
panic("I'm an error")
})
```
## Config
| Property | Type | Description | Default |
|:------------------|:-----------------------------|:------------------------------------------------------|:---------------------------|
| Next | `func(fiber.Ctx) bool` | Skip when the function returns `true`. | `nil` |
| PanicHandler | `func(fiber.Ctx, any) error` | Customize the error returned from a recovered panic. | `DefaultPanicHandler` |
| EnableStackTrace | `bool` | Capture and include a stack trace in error responses. | `false` |
| StackTraceHandler | `func(fiber.Ctx, any)` | Handle the captured stack trace when enabled. | `defaultStackTraceHandler` |
## Default Config
```go
var ConfigDefault = Config{
Next: nil,
PanicHandler: DefaultPanicHandler,
StackTraceHandler: defaultStackTraceHandler,
EnableStackTrace: false,
}
// Set up a PanicHandler to hide internals.
app.Use(recoverer.New(recoverer.Config{PanicHandler: func(c fiber.Ctx, r any) error {
return fiber.ErrInternalServerError
}}))
// In more elaborate scenarios you can also create a custom error which can be processed differently in the fiber.ErrorHandler.
// See the tests for an example of such an ErrorHandler.
// You could also just wrap the default handler's error, e.g. fmt.Errorf("[RECOVERED]: %w", recoverer.DefaultPanicHandler(c, r))
app.Use(recoverer.New(recoverer.Config{PanicHandler: func(c fiber.Ctx, r any) error {
return &MyCustomRecoveredFromPanicError {
Inner: recoverer.DefaultPanicHandler(c, r),
}
}}))
```
================================================
FILE: docs/middleware/redirect.md
================================================
---
id: redirect
---
# Redirect
Redirect middleware maps old URLs to new ones using simple rules.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/redirect"
)
func main() {
app := fiber.New()
app.Use(redirect.New(redirect.Config{
Rules: map[string]string{
"/old": "/new",
"/old/*": "/new/$1",
},
StatusCode: fiber.StatusMovedPermanently,
}))
app.Get("/new", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Get("/new/*", func(c fiber.Ctx) error {
return c.SendString("Wildcard: " + c.Params("*"))
})
app.Listen(":3000")
}
```
## Test
```bash
curl http://localhost:3000/old
curl http://localhost:3000/old/hello
```
## Config
| Property | Type | Description | Default |
|:-----------|:--------------------|:------------------------------------------|:-----------------------|
| Next | `func(fiber.Ctx) bool` | Skip when function returns true. | nil |
| Rules | `map[string]string` | Map paths to new ones; `$1`, `$2` insert params. | Required |
| StatusCode | `int` | HTTP code for redirects. | `302 Found` |
## Default Config
```go
var ConfigDefault = Config{
StatusCode: fiber.StatusFound,
}
```
================================================
FILE: docs/middleware/requestid.md
================================================
---
id: requestid
---
# RequestID
The RequestID middleware generates or propagates a request identifier, adding it to the response headers and request context.
## Signatures
```go
func New(config ...Config) fiber.Handler
func FromContext(ctx any) string
```
`FromContext` accepts a `fiber.CustomCtx`, a `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`.
## Examples
Import the middleware package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/requestid"
)
```
Once your Fiber app is initialized, add the middleware like this:
```go
// Initialize default config
app.Use(requestid.New())
// Or extend your config for customization
app.Use(requestid.New(requestid.Config{
Header: "X-Custom-Header",
Generator: func() string {
return "static-id"
},
}))
```
If the request already includes the configured header, that value is reused instead of generating a new one. The middleware
rejects IDs containing characters outside the visible ASCII range (for example, control characters or obs-text bytes) and
will regenerate the value using up to three attempts from the configured generator (or SecureToken when no generator is set). When a
custom generator fails to produce a valid ID, the middleware falls back to SecureToken to keep headers RFC-compliant
across transports.
Retrieve the request ID inside a handler:
```go
func handler(c fiber.Ctx) error {
id := requestid.FromContext(c)
log.Printf("Request ID: %s", id)
return c.SendString("Hello, World!")
}
```
## Config
| Property | Type | Description | Default |
|:----------|:---------------------|:-----------------------------------------|:---------------|
| Next | `func(fiber.Ctx) bool` | Skip when the function returns `true`. | `nil` |
| Header | `string` | Header key used to store the request ID. | `"X-Request-ID"` |
| Generator | `func() string` | Function that generates the identifier. | `utils.SecureToken` |
## Default Config
The default config uses a cryptographically secure token generator for better security and privacy.
```go
var ConfigDefault = Config{
Next: nil,
Header: fiber.HeaderXRequestID,
Generator: utils.SecureToken,
}
```
================================================
FILE: docs/middleware/responsetime.md
================================================
---
id: responsetime
---
# ResponseTime
Response time middleware for [Fiber](https://github.com/gofiber/fiber) that measures the time spent handling a request and exposes it via a response header.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Examples
Import the package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/responsetime"
)
```
### Default config
```go
app.Use(responsetime.New())
```
### Custom header
```go
app.Use(responsetime.New(responsetime.Config{
Header: "X-Elapsed",
}))
```
### Skip logic
```go
app.Use(responsetime.New(responsetime.Config{
Next: func(c fiber.Ctx) bool {
return c.Path() == "/healthz"
},
}))
```
## Config
| Property | Type | Description | Default |
| :------- | :--- | :---------- | :------ |
| Next | `func(c fiber.Ctx) bool` | Defines a function to skip this middleware when it returns `true`. | `nil` |
| Header | `string` | Header key used to store the measured response time. If left empty, the default header is used. | `"X-Response-Time"` |
================================================
FILE: docs/middleware/rewrite.md
================================================
---
id: rewrite
---
# Rewrite
The Rewrite middleware remaps the request path using custom rules, helping with backward compatibility and cleaner URLs.
## Signatures
```go
func New(config ...Config) fiber.Handler
```
## Config
| Property | Type | Description | Default |
|:---------|:----------------------|:------------------------------------------------------|:-----------|
| Next | `func(fiber.Ctx) bool` | Skip when function returns `true`. | `nil` |
| Rules | `map[string]string` | Map paths to new values; use `$1`, `$2` for wildcard captures.| (Required) |
:::note
Rules are stored in a map, so iteration order is undefined. Avoid overlapping patterns if precedence matters.
:::
## Examples
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/rewrite"
)
func main() {
app := fiber.New()
app.Use(rewrite.New(rewrite.Config{
Rules: map[string]string{
"/old": "/new",
"/old/*": "/new/$1",
},
}))
app.Get("/new", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Get("/new/*", func(c fiber.Ctx) error {
return c.SendString("Wildcard: " + c.Params("*"))
})
app.Listen(":3000")
}
```
## Test
```bash
curl http://localhost:3000/old
curl http://localhost:3000/old/hello
```
================================================
FILE: docs/middleware/session.md
================================================
---
id: session
---
# Session
The Session middleware adds session management to Fiber apps through the [Storage](https://github.com/gofiber/storage) package, which offers a unified interface for multiple databases. By default, sessions live in memory, but you can plug in any storage backend.
## Table of Contents
- [Quick Start](#quick-start)
- [Usage Patterns](#usage-patterns)
- [Session Security](#session-security)
- [Session ID Extractors](#session-id-extractors)
- [Configuration](#configuration)
- [Migration Guide](#migration-guide)
- [API Reference](#api-reference)
- [Examples](#examples)
## Quick Start
```go
import (
"fmt"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
)
// Basic usage
app.Use(session.New())
app.Get("/", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Get and update visits count
var visits int
if v := sess.Get("visits"); v != nil {
// Use type assertion with an ok check to prevent a panic
if vInt, ok := v.(int); ok {
visits = vInt
}
}
visits++
sess.Set("visits", visits)
return c.SendString(fmt.Sprintf("Visits: %d", visits))
})
```
### Production Configuration
```go
import (
"time"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/storage/redis/v3"
)
storage := redis.New(redis.Config{
Host: "localhost",
Port: 6379,
})
app.Use(session.New(session.Config{
Storage: storage,
CookieSecure: true, // HTTPS only
CookieHTTPOnly: true, // Block JavaScript access (mitigates session theft via XSS)
CookieSameSite: "Lax", // CSRF protection
IdleTimeout: 30 * time.Minute, // Session timeout
AbsoluteTimeout: 24 * time.Hour, // Maximum session life
Extractor: extractors.FromCookie("__Host-session_id"),
}))
```
Notes:
- AbsoluteTimeout must be greater than or equal to IdleTimeout; otherwise, the middleware panics during configuration.
- If CookieSameSite is set to "None", the middleware automatically forces CookieSecure=true when setting the cookie.
## Usage Patterns
### Middleware Pattern (Recommended)
This pattern automatically manages the session lifecycle and is recommended for most applications.
```go
// Setup middleware
app.Use(session.New())
// Use in handlers
app.Post("/login", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Session is automatically saved when handler returns
sess.Set("user_id", 123)
sess.Set("authenticated", true)
return c.Redirect().To("/dashboard")
})
```
**Benefits:**
- Automatic session saving
- Automatic resource cleanup
- No manual lifecycle management
- Thread-safe operations
### Store Pattern (Advanced)
Use the store pattern for background tasks or when you need direct access to sessions.
```go
import (
"context"
"log"
"time"
)
store := session.NewStore()
// In background tasks
func backgroundTask(sessionID string) {
sess, err := store.GetByID(context.Background(), sessionID)
if err != nil {
return
}
defer sess.Release() // Important: Manual cleanup required
// Modify session
sess.Set("last_task", time.Now())
// Manual save required
if err := sess.Save(); err != nil {
log.Printf("Failed to save session: %v", err)
}
}
```
**Requirements:**
- Must call `sess.Release()` when done
- Must call `sess.Save()` to persist changes
- Handle errors manually
## Session Security
### Authentication Flow
Understanding session lifecycle during authentication is crucial for security.
#### Basic Login/Logout
```go
app.Post("/login", func(c fiber.Ctx) error {
sess := session.FromContext(c)
email := c.FormValue("email")
password := c.FormValue("password")
// Simple credential validation (use proper authentication in production)
if email == "admin@example.com" && password == "secret" {
// Important: Regenerate the session ID to prevent fixation
// This changes the session ID while preserving existing data
if err := sess.Regenerate(); err != nil {
return c.Status(500).SendString("Session error")
}
// Add authentication data to existing session
sess.Set("user_id", 1)
sess.Set("authenticated", true)
return c.Redirect().To("/dashboard")
}
return c.Status(401).SendString("Invalid credentials")
})
app.Post("/logout", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Complete session reset (clears all data + new session ID)
if err := sess.Reset(); err != nil {
return c.Status(500).SendString("Session error")
}
return c.Redirect().To("/")
})
```
#### Cart Preservation During Login
```go
app.Post("/login", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Validate credentials (implement your own validation)
email := c.FormValue("email")
password := c.FormValue("password")
if !isValidUser(email, password) {
return c.Status(401).JSON(fiber.Map{"error": "Invalid credentials"})
}
// Important: Regenerate the session ID to prevent fixation
// This changes the session ID while preserving existing data
if err := sess.Regenerate(); err != nil {
return c.Status(500).JSON(fiber.Map{"error": "Session error"})
}
// Add authentication data to existing session
sess.Set("user_id", getUserID(email))
sess.Set("authenticated", true)
sess.Set("login_time", time.Now())
return c.JSON(fiber.Map{"status": "logged in"})
})
```
### Security Methods Comparison
| Method | Session ID | Session Data | Use Case |
|--------|------------|--------------|----------|
| `Regenerate()` | ✅ Changes | ✅ Preserved | Login, privilege escalation |
| `Reset()` | ✅ Changes | ❌ Cleared | Logout, security breach |
| `Destroy()` | ⚪ Unchanged | ❌ Cleared | Clear data only |
### Common Security Mistakes
❌ **Session Fixation Vulnerability:**
```go
// DANGEROUS: Keeping same session ID after login
app.Post("/login", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Validate user...
sess.Set("user_id", userID) // Attacker can hijack this session!
return c.Redirect().To("/dashboard")
})
```
✅ **Secure Implementation:**
```go
// SECURE: Always regenerate session ID after authentication
app.Post("/login", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Validate user...
if err := sess.Regenerate(); err != nil { // Prevents session fixation
return err
}
sess.Set("user_id", userID)
return c.Redirect().To("/dashboard")
})
```
### Authentication Middleware
This is a basic example of an authentication middleware that checks if a user is logged in before accessing protected routes.
```go
// Authentication check middleware
func RequireAuth(c fiber.Ctx) error {
sess := session.FromContext(c)
if sess == nil {
return c.Redirect().To("/login")
}
// Check if user is authenticated
if sess.Get("authenticated") != true {
return c.Redirect().To("/login")
}
return c.Next()
}
// Usage
app.Use("/dashboard", RequireAuth)
app.Use("/admin", RequireAuth)
```
### Automatic Session Expiration
Sessions automatically expire based on your configuration:
```go
app.Use(session.New(session.Config{
IdleTimeout: 30 * time.Minute, // Auto-expire after 30 min of inactivity
AbsoluteTimeout: 24 * time.Hour, // Force expire after 24 hours regardless of activity
}))
```
**How it works:**
- `IdleTimeout`: Storage automatically removes sessions after inactivity period
- Any route that uses the middleware will reset the idle timer
- Calling `sess.Save()` will also reset the idle timer
- `AbsoluteTimeout`: Sessions are forcibly expired after maximum duration
- No manual cleanup required - the storage layer handles this
## Session ID Extractors
This middleware uses the shared extractors module for session ID extraction. See the [Extractors Guide](../guide/extractors) for more details.
### Built-in Extractors
```go
// Cookie-based (recommended for web apps)
extractors.FromCookie("session_id")
// Header-based (recommended for APIs)
extractors.FromHeader("X-Session-ID")
// Authorization header (read-only)
extractors.FromAuthHeader("Bearer")
// Form data
extractors.FromForm("session_id")
// URL query parameter
extractors.FromQuery("session_id")
// URL path parameter
extractors.FromParam("id")
```
**Session Response Behavior:**
- Cookie extractors: set cookie in the response
- Header extractors (non-Authorization): set header in the response
- Authorization header, Query, Form, Param, Custom: read-only (no response values are set)
### Multiple Sources with Fallback
```go
app.Use(session.New(session.Config{
Extractor: extractors.Chain(
extractors.FromCookie("session_id"), // Try cookie first
extractors.FromHeader("X-Session-ID"), // Then header
extractors.FromQuery("session_id"), // Finally query
),
}))
```
**Response Behavior with Chained Extractors:**
Only cookie and non-Authorization header extractors contribute to response setting. Others are read-only.
- Cookie + Header (non-Auth) extractors: both cookie and header are set
- Only Cookie extractors: only cookie is set
- Only Header (non-Auth) extractors: only header is set
- Any mix that includes Authorization/Query/Form/Param/Custom: those sources are read-only
```go
// This will set both cookie and header in response
extractors.Chain(
extractors.FromCookie("session_id"),
extractors.FromHeader("X-Session-ID"),
)
// This will set only cookie in response
extractors.Chain(
extractors.FromCookie("session_id"),
extractors.FromQuery("session_id"), // Ignored for response
)
// This will set nothing in response (read-only mode)
extractors.Chain(
extractors.FromQuery("session_id"),
extractors.FromForm("session_id"),
)
```
### Custom Extractors (Session-specific)
Prefer the helper constructors from the extractors module. See the Extractors Guide for the full API; below are session-specific examples and notes.
```go
// Authorization Bearer tokens (read-only for sessions)
// The session middleware will NOT set Authorization back in the response.
app.Use(session.New(session.Config{
Extractor: extractors.FromAuthHeader("Bearer"),
}))
```
```go
// Custom read-only header via FromCustom (read-only for sessions)
app.Use(session.New(session.Config{
Extractor: extractors.FromCustom("X-Custom-Session", func(c fiber.Ctx) (string, error) {
v := c.Get("X-Custom-Session")
if v == "" { return "", extractors.ErrNotFound }
return v, nil
}),
}))
```
## Configuration
### Storage Options
```go
import (
"github.com/gofiber/storage/redis/v3"
"github.com/gofiber/storage/postgres/v3"
)
// Redis (recommended for production)
redisStorage := redis.New(redis.Config{
Host: "localhost",
Port: 6379,
Password: "",
Database: 0,
})
// PostgreSQL
pgStorage := postgres.New(postgres.Config{
Host: "localhost",
Port: 5432,
Database: "sessions",
Username: "user",
Password: "pass",
})
app.Use(session.New(session.Config{
Storage: redisStorage,
}))
```
### Production Security Settings
```go
import (
"log"
"time"
"github.com/gofiber/utils/v2"
"github.com/gofiber/fiber/v3/extractors"
)
app.Use(session.New(session.Config{
// Storage
Storage: redisStorage,
// Security
CookieSecure: true, // HTTPS only (required in production)
CookieHTTPOnly: true, // No JavaScript access (mitigates session theft via XSS)
CookieSameSite: "Lax", // CSRF protection
// Session Management
IdleTimeout: 30 * time.Minute, // Inactivity timeout
AbsoluteTimeout: 24 * time.Hour, // Maximum session duration
// Cookie Settings
CookiePath: "/",
// CookieDomain is left unset: browsers reject "__Host-" prefixed cookies that carry a Domain attribute
CookieSessionOnly: false, // Persist across browser restarts
// Session ID
Extractor: extractors.FromCookie("__Host-session_id"),
KeyGenerator: utils.SecureToken,
// Error Handling
ErrorHandler: func(c fiber.Ctx, err error) {
log.Printf("Session error: %v", err)
},
}))
```
### Custom Types
Session data supports basic Go types by default:
- `string`, `int`, `int8`, `int16`, `int32`, `int64`
- `uint`, `uint8`, `uint16`, `uint32`, `uint64`
- `bool`, `float32`, `float64`
- `[]byte`, `complex64`, `complex128`
- `interface{}`
For custom types (structs, maps, slices), you must register them for encoding/decoding:
```go
import "fmt"
type User struct {
ID int `json:"id"`
Name string `json:"name"`
Role string `json:"role"`
}
// Method 1: Using NewWithStore
func main() {
app := fiber.New()
sessionMiddleware, store := session.NewWithStore()
store.RegisterType(User{}) // Register custom type
app.Use(sessionMiddleware)
app.Get("/", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Use custom type
sess.Set("user", User{ID: 123, Name: "John", Role: "admin"})
user, ok := sess.Get("user").(User)
if ok {
return c.JSON(fiber.Map{"user": user.Name, "role": user.Role})
}
return c.SendString("No user found")
})
app.Listen(":3000")
}
```
```go
// Method 2: Using separate store
store := session.NewStore()
store.RegisterType(User{})
app.Use(session.New(session.Config{
Store: store,
}))
// Usage in handlers
sess.Set("user", User{ID: 123, Name: "John", Role: "admin"})
user, ok := sess.Get("user").(User)
if ok {
fmt.Printf("User: %s (Role: %s)", user.Name, user.Role)
}
```
**Important Notes:**
- Custom types must be registered before using them in sessions
- Registration must happen during application startup
- All instances of the application must register the same types
- Types are encoded using Go's `gob` package
## Migration Guide
### v2 to v3 Breaking Changes
1. **Function Signature**: `session.New()` now returns middleware handler, not store
2. **Session ID Extraction**: `KeyLookup` replaced with `Extractor` functions
3. **Lifecycle Management**: Manual `Release()` required for store pattern
4. **Timeout Handling**: `Expiration` split into `IdleTimeout` and `AbsoluteTimeout`
### Migration Examples
**v2 Code:**
```go
store := session.New(session.Config{
KeyLookup: "cookie:session_id",
})
app.Get("/", func(c *fiber.Ctx) error {
sess, err := store.Get(c)
if err != nil {
return err
}
sess.Set("key", "value")
// v2 requires an explicit Save, which also releases the session
return sess.Save()
})
```
**v3 Middleware Pattern (Recommended):**
```go
app.Use(session.New(session.Config{
Extractor: extractors.FromCookie("session_id"),
}))
app.Get("/", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Session automatically saved and released
sess.Set("key", "value")
return nil
})
```
**v3 Store Pattern (Advanced):**
```go
store := session.NewStore(session.Config{
Extractor: extractors.FromCookie("session_id"),
})
app.Get("/", func(c fiber.Ctx) error {
sess, err := store.Get(c)
if err != nil {
return err
}
defer sess.Release() // Manual cleanup required
sess.Set("key", "value")
return sess.Save() // Manual save required
})
```
### KeyLookup to Extractor Migration
| v2 KeyLookup | v3 Extractor |
|---------------------------------|------------------------------------------------------------------------------------|
| `"cookie:session_id"` | `extractors.FromCookie("session_id")` |
| `"header:X-Session-ID"` | `extractors.FromHeader("X-Session-ID")` |
| `"query:session_id"` | `extractors.FromQuery("session_id")` |
| `"form:session_id"` | `extractors.FromForm("session_id")` |
| `"cookie:sid,header:X-Sid"` | `extractors.Chain(extractors.FromCookie("sid"), extractors.FromHeader("X-Sid"))` |
## API Reference
### Middleware Methods (Recommended)
```go
sess := session.FromContext(c)
// Data operations
sess.Get(key any) any
sess.Set(key, value any)
sess.Delete(key any)
sess.Keys() []any
// Session management
sess.ID() string
sess.Fresh() bool
sess.Regenerate() error // Change ID, keep data
sess.Reset() error // Change ID, clear data
sess.Destroy() error // Keep ID, clear data
// Store access
sess.Store() *session.Store
```
`FromContext` accepts a `fiber.CustomCtx`, a `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`.
### Store Methods
```go
store := session.NewStore()
// Store operations
store.Get(c fiber.Ctx) (*session.Session, error)
store.GetByID(ctx context.Context, sessionID string) (*session.Session, error)
store.Reset(ctx context.Context) error
store.Delete(ctx context.Context, sessionID string) error
// Type registration
store.RegisterType(interface{})
```
### Session Methods (Store Pattern)
```go
sess, err := store.Get(c)
defer sess.Release() // Required!
// Same methods as middleware, plus:
sess.Save() error // Manual save required
sess.SetIdleTimeout(duration) // Per-session timeout
sess.Release() // Manual cleanup required
```
### Extractor Functions
```go
// Built-in extractors (import "github.com/gofiber/fiber/v3/extractors")
extractors.FromCookie(key string) extractors.Extractor
extractors.FromHeader(key string) extractors.Extractor
extractors.FromQuery(key string) extractors.Extractor
extractors.FromForm(key string) extractors.Extractor
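extractors.FromParam(key string) extractors.Extractor
extractors.FromAuthHeader(authScheme string) extractors.Extractor
extractors.FromCustom(key string, fn func(c fiber.Ctx) (string, error)) extractors.Extractor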
// Chaining
extractors.Chain(extractors ...extractors.Extractor) extractors.Extractor
```
### Config Properties
| Property | Type | Description | Default |
|---------------------|-----------------------------|-----------------------------|--------------------------------------------|
| `Store` | `*session.Store` | Pre-built session store (use when you need to share/register types) | `nil` (auto-created) |
| `Storage` | `fiber.Storage` | Session storage backend (used when creating a store if `Store` is nil) | `memory.New()` |
| `Extractor` | `extractors.Extractor` | Session ID extraction | `extractors.FromCookie("session_id")` |
| `KeyGenerator` | `func() string` | Session ID generator | `utils.SecureToken` |
| `IdleTimeout` | `time.Duration` | Inactivity timeout | `30 * time.Minute` |
| `AbsoluteTimeout` | `time.Duration` | Maximum session duration | `0` (unlimited) |
| `CookieSecure` | `bool` | HTTPS only | `false` |
| `CookieHTTPOnly` | `bool` | No JavaScript access | `false` |
| `CookieSameSite` | `string` | SameSite attribute | `"Lax"` |
| `CookiePath` | `string` | Cookie path | `""` |
| `CookieDomain` | `string` | Cookie domain | `""` |
| `CookieSessionOnly` | `bool` | Session cookie | `false` |
| `Next` | `func(fiber.Ctx) bool` | Skip middleware when returns true | `nil` |
| `ErrorHandler` | `func(fiber.Ctx, error)` | Error callback | `DefaultErrorHandler` |
## Examples
### E-commerce with Cart Persistence
```go
import (
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/storage/redis/v3"
)
func main() {
app := fiber.New()
// Session middleware
app.Use(session.New(session.Config{
Storage: redis.New(),
CookieSecure: true,
CookieHTTPOnly: true,
CookieSameSite: "Lax",
IdleTimeout: 30 * time.Minute,
AbsoluteTimeout: 24 * time.Hour,
Extractor: extractors.FromCookie("__Host-cart_session"),
}))
// Add to cart (anonymous user)
app.Post("/cart/add", func(c fiber.Ctx) error {
sess := session.FromContext(c)
cart, _ := sess.Get("cart").([]string)
cart = append(cart, c.FormValue("item_id"))
sess.Set("cart", cart)
return c.JSON(fiber.Map{"items": len(cart)})
})
// Login (preserve session data)
app.Post("/login", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Simple validation (implement proper authentication)
email := c.FormValue("email")
password := c.FormValue("password")
if email != "user@example.com" || password != "password" {
return c.Status(401).JSON(fiber.Map{"error": "Invalid credentials"})
}
// Regenerate session ID for security
// This changes the session ID while preserving existing data
if err := sess.Regenerate(); err != nil {
return c.Status(500).JSON(fiber.Map{"error": "Session error"})
}
sess.Set("user_id", 1)
sess.Set("authenticated", true)
return c.JSON(fiber.Map{"status": "logged in"})
})
// Logout (clear everything)
app.Post("/logout", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Reset clears all data and generates new session ID
if err := sess.Reset(); err != nil {
return c.Status(500).JSON(fiber.Map{"error": "Session error"})
}
return c.JSON(fiber.Map{"status": "logged out"})
})
app.Listen(":3000")
}
// Helper functions (implement these properly in production)
func isValidUser(email, password string) bool {
return email == "user@example.com" && password == "password"
}
func getUserID(email string) int {
return 1 // Return actual user ID from database
}
```
### API with Header-based Sessions
```go
import (
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/storage/redis/v3"
)
func main() {
app := fiber.New()
// API session middleware with header extraction
app.Use(session.New(session.Config{
Storage: redis.New(),
Extractor: extractors.FromHeader("X-Session-Token"),
IdleTimeout: time.Hour,
}))
// API endpoint
app.Post("/api/data", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Track API usage
count, _ := sess.Get("api_calls").(int)
count++
sess.Set("api_calls", count)
sess.Set("last_call", time.Now())
return c.JSON(fiber.Map{
"data": "some data",
"calls": count,
})
})
app.Listen(":3000")
}
```
### Multi-source Session ID Support
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/session"
"github.com/gofiber/fiber/v3/extractors"
)
func main() {
app := fiber.New()
// Support multiple sources with priority
app.Use(session.New(session.Config{
Extractor: extractors.Chain(
extractors.FromCookie("session_id"), // 1st: Cookie (web)
extractors.FromHeader("X-Session-ID"), // 2nd: Header (API)
extractors.FromQuery("session_id"), // 3rd: Query (fallback)
),
}))
app.Get("/", func(c fiber.Ctx) error {
sess := session.FromContext(c)
// Works with any of the above methods
return c.JSON(fiber.Map{
"session_id": sess.ID(),
"source": "multi-source",
})
})
app.Listen(":3000")
}
```
================================================
FILE: docs/middleware/skip.md
================================================
---
id: skip
---
# Skip
The Skip middleware wraps a handler and bypasses it when the predicate returns `true` for the current request.
## Signatures
```go
func New(handler fiber.Handler, exclude func(c fiber.Ctx) bool) fiber.Handler
```
## Examples
Import the package:
```go
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/skip"
)
```
`skip.New` accepts the handler to wrap and a predicate function. The predicate
runs for every request, and returning `true` skips the wrapped handler and
executes the next middleware in the chain.
After you initialize your Fiber app, use `skip.New` like this:
```go
func main() {
app := fiber.New()
app.Use(skip.New(BasicHandler, func(ctx fiber.Ctx) bool {
return ctx.Method() == fiber.MethodGet
}))
app.Get("/", func(ctx fiber.Ctx) error {
return ctx.SendString("It was a GET request!")
})
log.Fatal(app.Listen(":3000"))
}
func BasicHandler(ctx fiber.Ctx) error {
return ctx.SendString("It was not a GET request!")
}
```
:::tip
`app.Use` processes requests on any route and method. In the example above, the handler is skipped only for `GET`.
:::
================================================
FILE: docs/middleware/static.md
================================================
---
id: static
---
# Static
The Static middleware serves assets such as **images**, **CSS**, and **JavaScript**.
:::info
By default, it serves `index.html` when a directory is requested. Customize this behavior in the [Config](#config) options.
:::
## Signatures
```go
func New(root string, cfg ...Config) fiber.Handler
```
## Examples
Import the package:
```go
import(
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/static"
)
```
### Serving files from a directory
```go
app.Get("/*", static.New("./public"))
```
Test
```sh
curl http://localhost:3000/hello.html
curl http://localhost:3000/css/style.css
```
### Serving files from a directory with `Use`
```go
app.Use("/", static.New("./public"))
```
Test
```sh
curl http://localhost:3000/hello.html
curl http://localhost:3000/css/style.css
```
### Serving a single file
```go
app.Use("/static", static.New("./public/hello.html"))
```
Test
```sh
curl http://localhost:3000/static # will show hello.html
curl http://localhost:3000/static/john/doe # will show hello.html
```
### Serving files using os.DirFS
```go
app.Get("/files*", static.New("", static.Config{
FS: os.DirFS("files"),
Browse: true,
}))
```
Test
```sh
curl http://localhost:3000/files/css/style.css
curl http://localhost:3000/files/index.html
```
### Serving files using embed.FS
```go
//go:embed path/to/files
var myfiles embed.FS
app.Get("/files*", static.New("", static.Config{
FS: myfiles,
Browse: true,
}))
```
Test
```sh
curl http://localhost:3000/files/css/style.css
curl http://localhost:3000/files/index.html
```
### SPA (Single Page Application)
```go
app.Use("/web", static.New("", static.Config{
FS: os.DirFS("dist"),
}))
app.Get("/web*", func(c fiber.Ctx) error {
return c.SendFile("dist/index.html")
})
```
Test
```sh
curl http://localhost:3000/web/css/style.css
curl http://localhost:3000/web/index.html
curl http://localhost:3000/web
```
:::caution
To define static routes using `Get`, append the wildcard (`*`) operator at the end of the route.
:::
## Config
| Property | Type | Description | Default |
|:-----------|:------------------------|:---------------------------------------------------------------------------------------------------------------------------|:-----------------------|
| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` |
| FS | `fs.FS` | FS is the file system to serve the static files from.<br /><br />You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc. | `nil` |
| Compress | `bool` | When set to true, the server tries minimizing CPU usage by caching compressed files. The middleware will compress the response using `gzip`, `brotli`, or `zstd` compression depending on the [Accept-Encoding](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding) header.<br /><br />This works differently from the github.com/gofiber/compression middleware. | `false` |
| ByteRange | `bool` | When set to true, enables byte range requests. | `false` |
| Browse | `bool` | When set to true, enables directory browsing. | `false` |
| Download | `bool` | When set to true, enables direct download. | `false` |
| IndexNames | `[]string` | The names of the index files for serving a directory. | `[]string{"index.html"}` |
| CacheDuration | `time.Duration` | Expiration duration for inactive file handlers.<br /><br />Use a negative time.Duration to disable it. | `10 * time.Second` |
| MaxAge | `int` | The value for the Cache-Control HTTP-header that is set on the file response. MaxAge is defined in seconds. | `0` |
| ModifyResponse | `fiber.Handler` | ModifyResponse defines a function that allows you to alter the response. | `nil` |
| NotFoundHandler | `fiber.Handler` | NotFoundHandler defines a function to handle when the path is not found. | `nil` |
When **Download** is enabled, the response includes a `Content-Disposition` header with the requested filename. Non-ASCII names use the `filename*` parameter as defined by [RFC 6266](https://www.rfc-editor.org/rfc/rfc6266) and [RFC 8187](https://www.rfc-editor.org/rfc/rfc8187).
:::info
You can set `CacheDuration` config property to `-1` to disable caching.
:::
## Default Config
```go
var ConfigDefault = Config{
IndexNames: []string{"index.html"},
CacheDuration: 10 * time.Second,
}
```
================================================
FILE: docs/middleware/timeout.md
================================================
---
id: timeout
---
# Timeout
The timeout middleware enforces a deadline on handler execution. It wraps handlers with
`context.WithTimeout`, exposes the derived context through `c.Context()`, and
returns `408 Request Timeout` when the deadline is exceeded.
## How It Works
When a timeout occurs, the middleware **returns immediately** without waiting for the
handler to finish. This is achieved through Fiber's **Abandon mechanism**:
1. The handler runs in a goroutine with a timeout context
2. On timeout, the middleware marks the context as "abandoned" and returns `408` immediately
3. The handler goroutine can continue safely (e.g., for cleanup) without blocking the response
4. A background cleanup goroutine waits for the handler to finish and performs context cleanup
Handlers can detect the timeout by listening on `c.Context().Done()` and return early.
This is the recommended pattern for cooperative cancellation.
If a handler panics, the middleware catches it and returns `500 Internal Server Error`.
## Known limitations
- Timed-out requests abandon their `fiber.Ctx` to avoid data races with the core
request handler (including the `ErrorHandler`). These contexts are **not**
returned to the pool, so each timed-out request leaks a context. Calling
`ForceRelease` is only safe if you can guarantee that no goroutine (including
Fiber internals) will touch the context anymore; the timeout middleware
intentionally does not call it.
:::caution
`timeout.New` wraps your final handler and can't be added with `app.Use` or
used in a middleware chain. Register it per route and avoid calling
`c.Next()` inside the wrapped handler—doing so will panic.
:::
## Signatures
```go
func New(handler fiber.Handler, config ...timeout.Config) fiber.Handler
```
## Examples
### Basic example
The following program times out any request that takes longer than two seconds.
The handler simulates work with `sleepWithContext`, which stops when the
context is canceled:
```go
package main
import (
"context"
"fmt"
"log"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/timeout"
)
func sleepWithContext(ctx context.Context, d time.Duration) error {
select {
case <-time.After(d):
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func main() {
app := fiber.New()
handler := func(c fiber.Ctx) error {
delay, _ := time.ParseDuration(c.Params("delay") + "ms")
if err := sleepWithContext(c.Context(), delay); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return c.SendString("finished")
}
app.Get("/sleep/:delay", timeout.New(handler, timeout.Config{
Timeout: 2 * time.Second,
}))
log.Fatal(app.Listen(":3000"))
}
```
Use these requests to see the middleware in action:
```bash
curl -i http://localhost:3000/sleep/1000 # finishes within the timeout
curl -i http://localhost:3000/sleep/3000 # returns 408 Request Timeout
```
## Config
| Property | Type | Description | Default |
|:------------|:-------------------|:---------------------------------------------------------------------|:-------|
| Next | `func(fiber.Ctx) bool` | Function to skip this middleware when it returns `true`. | `nil` |
| Timeout | `time.Duration` | Timeout duration for requests. `0` or a negative value disables the timeout. | `0` |
| OnTimeout | `fiber.Handler` | Handler executed when a timeout occurs. Defaults to returning `fiber.ErrRequestTimeout`. | `nil` |
| Errors | `[]error` | Custom errors treated as timeout errors. | `nil` |
### Use with a custom error
```go
var ErrFooTimeOut = errors.New("foo context canceled")
func main() {
app := fiber.New()
h := func(c fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContextWithCustomError(c.Context(), sleepTime); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}
app.Get("/foo/:sleepTime", timeout.New(h, timeout.Config{Timeout: 2 * time.Second, Errors: []error{ErrFooTimeOut}}))
log.Fatal(app.Listen(":3000"))
}
func sleepWithContextWithCustomError(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return ErrFooTimeOut
case <-timer.C:
}
return nil
}
```
### Sample usage with a database call
```go
func main() {
app := fiber.New()
db, _ := gorm.Open(postgres.Open("postgres://localhost/foodb"), &gorm.Config{})
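handler := func(ctx fiber.Ctx) error {
// Bind the transaction to the timeout-aware context so the query is
// canceled once the deadline is exceeded.
tran := db.WithContext(ctx.Context()).Begin()
if tran.Error != nil {
return tran.Error
}
if err := tran.Exec("SELECT pg_sleep(50)").Error; err != nil {
// Roll back so the connection is not left inside an open transaction.
tran.Rollback()
return err
}
return tran.Commit().Error
}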
app.Get("/foo", timeout.New(handler, timeout.Config{Timeout: 10 * time.Second}))
log.Fatal(app.Listen(":3000"))
}
```
================================================
FILE: docs/partials/routing/handler.md
================================================
---
id: route-handlers
title: Route Handlers
---
import Reference from '@site/src/components/reference';
Registers a route bound to a specific [HTTP method](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods).
```go title="Signatures"
// HTTP methods
func (app *App) Get(path string, handler any, handlers ...any) Router
func (app *App) Head(path string, handler any, handlers ...any) Router
func (app *App) Post(path string, handler any, handlers ...any) Router
func (app *App) Put(path string, handler any, handlers ...any) Router
func (app *App) Delete(path string, handler any, handlers ...any) Router
func (app *App) Connect(path string, handler any, handlers ...any) Router
func (app *App) Options(path string, handler any, handlers ...any) Router
func (app *App) Trace(path string, handler any, handlers ...any) Router
func (app *App) Patch(path string, handler any, handlers ...any) Router
// Add allows you to specify multiple methods at once
// The provided handlers are executed in order, starting with `handler` and then the variadic `handlers`.
func (app *App) Add(methods []string, path string, handler any, handlers ...any) Router
// All will register the route on all HTTP methods
// Almost the same as app.Use but not bound to prefixes
func (app *App) All(path string, handler any, handlers ...any) Router
```
Fiber's adapter converts a variety of handler shapes to native
`func(fiber.Ctx) error` callbacks. It currently recognizes seventeen cases (the
numbers below match the comments in `toFiberHandler` inside `adapter.go`). This
lets you mix Fiber-style handlers with Express-style callbacks and even reuse
`net/http` or `fasthttp` functions.
### Fiber-native handlers (cases 1–2)
- **Case 1.** `fiber.Handler` — the canonical `func(fiber.Ctx) error` form.
- **Case 2.** `func(fiber.Ctx)` — Fiber runs the function and treats it as if it
returned `nil`.
### Express-style request handlers (cases 3–12)
- **Case 3.** `func(fiber.Req, fiber.Res) error`
- **Case 4.** `func(fiber.Req, fiber.Res)`
- **Case 5.** `func(fiber.Req, fiber.Res, func() error) error`
- **Case 6.** `func(fiber.Req, fiber.Res, func() error)`
- **Case 7.** `func(fiber.Req, fiber.Res, func()) error`
- **Case 8.** `func(fiber.Req, fiber.Res, func())`
- **Case 9.** `func(fiber.Req, fiber.Res, func(error))`
- **Case 10.** `func(fiber.Req, fiber.Res, func(error)) error`
- **Case 11.** `func(fiber.Req, fiber.Res, func(error) error)`
- **Case 12.** `func(fiber.Req, fiber.Res, func(error) error) error`
The adapter injects a `next` callback when your signature accepts one. Fiber
propagates downstream errors from `c.Next()` back through the wrapper, so
returning those errors remains optional. If you never call the injected `next`
function, the handler chain stops, matching Express semantics.
When you accept `next` callbacks that take an `error`, calling `next(nil)`
continues the chain and passing a non-nil error short-circuits with that error.
If the handler itself returns an error, Fiber prioritizes that value over any
recorded `next` error.
### net/http handlers (cases 13–15)
- **Case 13.** `http.HandlerFunc`
- **Case 14.** `http.Handler`
- **Case 15.** `func(http.ResponseWriter, *http.Request)`
:::caution Compatibility overhead
Fiber adapts these handlers through `fasthttpadaptor`. They do not receive
`fiber.Ctx`, cannot call `c.Next()`, and therefore always terminate the handler
chain. The compatibility layer also adds more overhead than running a native
Fiber handler, so prefer the other forms when possible.
:::
### fasthttp handlers (cases 16–17)
- **Case 16.** `fasthttp.RequestHandler`
- **Case 17.** `func(*fasthttp.RequestCtx) error`
fasthttp handlers run with full access to the underlying `fasthttp.RequestCtx`.
They are expected to manage the response directly. Fiber will propagate any
error returned by the `func(*fasthttp.RequestCtx) error` variant but otherwise
does not inspect the context state.
```go title="Examples"
// Simple GET handler (Fiber accepts both func(fiber.Ctx) and func(fiber.Ctx) error)
app.Get("/api/list", func(c fiber.Ctx) error {
return c.SendString("I'm a GET request!")
})
// Reuse an existing net/http handler without manual adaptation
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
app.Get("/foo", httpHandler)
// Align with Express-style handlers using fiber.Req and fiber.Res helpers (works
// for middleware and routes alike)
app.Use(func(req fiber.Req, res fiber.Res, next func() error) error {
if req.IP() == "192.168.1.254" {
return res.SendStatus(fiber.StatusForbidden)
}
return next()
})
app.Get("/express", func(req fiber.Req, res fiber.Res) error {
return res.SendString("Hello from Express-style handlers!")
})
// Mount a fasthttp.RequestHandler directly
app.Get("/bar", func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(fiber.StatusAccepted)
})
// Simple POST handler
app.Post("/api/register", func(c fiber.Ctx) error {
return c.SendString("I'm a POST request!")
})
```
## Use
Can be used for middleware packages and prefix catchers. Prefixes now require either an exact match or a slash boundary, so `/john` matches `/john` and `/john/doe` but not `/johnnnnn`. Parameter tokens like `:name`, `:name?`, `*`, and `+` are still expanded before the boundary check runs.
```go title="Signature"
func (app *App) Use(args ...any) Router
// Fiber inspects args to support these common usage patterns:
// - app.Use(handler, handlers ...any)
// - app.Use(path string, handler, handlers ...any)
// - app.Use(paths []string, handler, handlers ...any)
// - app.Use(path string, subApp *App)
```
Each handler argument can independently be a Fiber handler (with or without an
`error` return), an Express-style callback, a `net/http` handler, or any other
supported shape including fasthttp callbacks that return errors.
```go title="Examples"
// Match any request
app.Use(func(c fiber.Ctx) error {
return c.Next()
})
// Match request starting with /api
app.Use("/api", func(c fiber.Ctx) error {
return c.Next()
})
// Match requests starting with /api or /home (multiple-prefix support)
app.Use([]string{"/api", "/home"}, func(c fiber.Ctx) error {
return c.Next()
})
// Attach multiple handlers
app.Use("/api", func(c fiber.Ctx) error {
c.Set("X-Custom-Header", random.String(32))
return c.Next()
}, func(c fiber.Ctx) error {
return c.Next()
})
// Mount a sub-app
app.Use("/api", api)
```
================================================
FILE: docs/whats_new.md
================================================
---
id: whats_new
title: 🆕 What's New in v3
sidebar_position: 2
toc_max_heading_level: 4
---
## 🎉 Welcome
We are excited to announce the release of Fiber v3! 🚀
In this guide, we'll walk you through the most important changes in Fiber `v3` and show you how to migrate your existing Fiber `v2` applications to Fiber `v3`.
### 🛠️ Migration tool
Fiber v3 introduces a CLI-powered migration helper. Install the CLI and let
it update your project automatically:
```bash
go install github.com/gofiber/cli/fiber@latest
fiber migrate --to v3
```
See the [migration guide](#-migration-guide) for more details and options.
Here's a quick overview of the changes in Fiber `v3`:
- [🚀 App](#-app)
- [🎣 Hooks](#-hooks)
- [🚀 Listen](#-listen)
- [🗺️ Router](#-router)
- [🧠 Context](#-context)
- [📎 Binding](#-binding)
- [🔬 Extractors Package](#-extractors-package)
- [🔄️ Redirect](#-redirect)
- [🌎 Client package](#-client-package)
- [🧰 Generic functions](#-generic-functions)
- [🛠️ Utils](#utils)
- [🧩 Services](#-services)
- [📃 Log](#-log)
- [📦 Storage Interface](#-storage-interface)
- [🧬 Middlewares](#-middlewares)
- [Important Change for Accessing Middleware Data](#important-change-for-accessing-middleware-data)
- [Adaptor](#adaptor)
- [BasicAuth](#basicauth)
- [Cache](#cache)
- [CORS](#cors)
- [CSRF](#csrf)
- [Compression](#compression)
- [EncryptCookie](#encryptcookie)
- [Favicon](#favicon)
- [Filesystem](#filesystem)
- [Healthcheck](#healthcheck)
- [KeyAuth](#keyauth)
- [Logger](#logger)
- [Monitor](#monitor)
- [Proxy](#proxy)
- [Recover](#recover)
- [Session](#session)
- [🔌 Addons](#-addons)
- [📋 Migration guide](#-migration-guide)
## Dropping support for old Go versions
Fiber `v3` requires Go `1.25` or later. Update your toolchain to `1.25+` before upgrading so the module `go` directive and standard library features align with the new minimum version.
## 🚀 App
We have made several changes to the Fiber app, including:
- **Listen**: The `Listen` method has been unified with the configuration, allowing for more streamlined setup.
- **Static**: The `Static` method has been removed and its functionality has been moved to the [static middleware](./middleware/static.md).
- **app.Config properties**: Several properties have been moved to the listen configuration:
- `DisableStartupMessage`
- `EnablePrefork` (previously `Prefork`)
- `EnablePrintRoutes`
- `ListenerNetwork` (previously `Network`)
- **Trusted Proxy Configuration**: The `EnableTrustedProxyCheck` has been moved to `app.Config.TrustProxy`, and `TrustedProxies` has been moved to `TrustProxyConfig.Proxies`. Additionally, `ProxyHeader` must be set to read client IPs from proxy headers (e.g., `X-Forwarded-For`).
- **XMLDecoder Config Property**: The `XMLDecoder` property has been added to allow usage of 3rd-party XML libraries in XML binder.
### New Methods
- **RegisterCustomBinder**: Allows for the registration of custom binders.
- **RegisterCustomConstraint**: Allows for the registration of custom constraints.
- **NewWithCustomCtx**: Initialize an app with a custom context in one step.
- **State**: Provides a global state for the application, which can be used to store and retrieve data across the application. Check out the [State](./api/state) method for further details.
- **NewErrorf**: Allows variadic parameters when creating formatted errors.
- **GetBytes / GetString**: Helpers that detach values only when `Immutable` is enabled and the data still references request or response buffers. Access via `c.App().GetString` and `c.App().GetBytes`.
- **ReloadViews**: Lets you re-run the configured view engine's `Load()` logic at runtime, including guard rails for missing or nil view engines so development hot-reload hooks can refresh templates safely.
#### Custom Route Constraints
Custom route constraints enable you to define your own validation rules for route parameters.
Use `RegisterCustomConstraint` to add a constraint type that implements the `CustomConstraint` interface.
Example
```go
type UlidConstraint struct {
fiber.CustomConstraint
}
func (*UlidConstraint) Name() string {
return "ulid"
}
func (*UlidConstraint) Execute(param string, args ...string) bool {
_, err := ulid.Parse(param)
return err == nil
}
app.RegisterCustomConstraint(&UlidConstraint{})
app.Get("/login/:id<ulid>", func(c fiber.Ctx) error {
return c.SendString("User " + c.Params("id"))
})
```
### Removed Methods
- **Mount**: Use `app.Use()` instead.
- **ListenTLS**: Use `app.Listen()` with `tls.Config`.
- **ListenTLSWithCertificate**: Use `app.Listen()` with `tls.Config`.
- **ListenMutualTLS**: Use `app.Listen()` with `tls.Config`.
- **ListenMutualTLSWithCertificate**: Use `app.Listen()` with `tls.Config`.
### Method Changes
- **Test**: The `Test` method has replaced the timeout parameter with a configuration parameter. `0` or lower represents no timeout.
- **Listen**: Now has a configuration parameter.
- **Listener**: Now has a configuration parameter.
### Custom Ctx Interface in Fiber v3
Fiber v3 introduces a customizable `Ctx` interface, allowing developers to extend and modify the context to fit their needs. This feature provides greater flexibility and control over request handling.
#### Idea Behind Custom Ctx Classes
The idea behind custom `Ctx` classes is to give developers the ability to extend the default context with additional methods and properties tailored to the specific requirements of their application. This allows for better request handling and easier implementation of specific logic.
#### NewWithCustomCtx
`NewWithCustomCtx` creates the application and sets the custom context factory at initialization time.
```go title="Signature"
func NewWithCustomCtx(fn func(app *App) CustomCtx, config ...Config) *App
```
Example
```go
package main
import (
"log"
"github.com/gofiber/fiber/v3"
)
type CustomCtx struct {
fiber.DefaultCtx
}
func (c *CustomCtx) CustomMethod() string {
return "custom value"
}
func main() {
app := fiber.NewWithCustomCtx(func(app *fiber.App) fiber.CustomCtx {
return &CustomCtx{
DefaultCtx: *fiber.NewDefaultCtx(app),
}
})
app.Get("/", func(c fiber.Ctx) error {
customCtx := c.(*CustomCtx)
return c.SendString(customCtx.CustomMethod())
})
log.Fatal(app.Listen(":3000"))
}
```
This example creates a `CustomCtx` with an extra `CustomMethod` and initializes the app with `NewWithCustomCtx`.
### Configurable TLS Minimum Version
We have added support for configuring the minimum TLS version. This field sets the minimum TLS version used both by the `AutoCertManager` setup and by the server listener.
```go
app.Listen(":444", fiber.ListenConfig{TLSMinVersion: tls.VersionTLS12})
```
#### TLS AutoCert support (ACME / Let's Encrypt)
We have added native support for automatic certificate management with Let's Encrypt and any other ACME-based provider.
```go
// Certificate manager
certManager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
// Replace with your domain name
HostPolicy: autocert.HostWhitelist("example.com"),
// Folder to store the certificates
Cache: autocert.DirCache("./certs"),
}
app.Listen(":444", fiber.ListenConfig{
AutoCertManager: certManager,
})
```
### MIME Constants
`MIMEApplicationJavaScript` and `MIMEApplicationJavaScriptCharsetUTF8` are deprecated. Use `MIMETextJavaScript` and `MIMETextJavaScriptCharsetUTF8` instead.
## 🎣 Hooks
We have made several changes to the Fiber hooks, including:
- Added new lifecycle hooks to provide better control over the shutdown process and the startup message:
- `OnPreShutdown` - Executes before the server starts shutting down
- `OnPostShutdown` - Executes after the server has shut down, receives any shutdown error
- `OnPreStartupMessage` - Executes before the startup message is printed, allowing customization of the banner and info entries
- `OnPostStartupMessage` - Executes after the startup message is printed, allowing post-startup logic
- Deprecated `OnShutdown` in favor of the new pre/post shutdown hooks
- Improved shutdown hook execution order and reliability
- Added mutex protection for hook registration and execution
**Important:** When using shutdown hooks, ensure `app.Listen()` is called in a separate goroutine:
```go
// Correct usage
go app.Listen(":3000")
// ... register shutdown hooks
app.Shutdown()
// Incorrect usage - hooks won't work
app.Listen(":3000") // This blocks
app.Shutdown() // Never reached
```
## 🚀 Listen
We have made several changes to the Fiber listen, including:
- Removed `OnShutdownError` and `OnShutdownSuccess` from `ListenConfig` in favor of using the `OnPostShutdown` hook, which receives the shutdown error
```go
app := fiber.New()
// Before - using ListenConfig callbacks
app.Listen(":3000", fiber.ListenConfig{
OnShutdownError: func(err error) {
log.Printf("Shutdown error: %v", err)
},
OnShutdownSuccess: func() {
log.Println("Shutdown successful")
},
})
// After - using OnPostShutdown hook
app.Hooks().OnPostShutdown(func(err error) error {
if err != nil {
log.Printf("Shutdown error: %v", err)
} else {
log.Println("Shutdown successful")
}
return nil
})
go app.Listen(":3000")
```
This change simplifies the shutdown handling by consolidating the shutdown callbacks into a single hook that receives the error status.
- Added support for Unix domain sockets via `ListenerNetwork` and `UnixSocketFileMode`
```go
// v2 - Requires manual deletion of old file and permissions change
app := fiber.New(fiber.Config{
Network: "unix",
})
os.Remove("app.sock")
app.Hooks().OnListen(func(fiber.ListenData) error {
return os.Chmod("app.sock", 0770)
})
app.Listen("app.sock")
// v3 - Fiber does it for you
app := fiber.New()
app.Listen("app.sock", fiber.ListenConfig{
ListenerNetwork: fiber.NetworkUnix,
UnixSocketFileMode: 0770,
})
```
- Added `TLSConfig` to `ListenConfig` so external providers can supply certificates via `GetCertificate`. Prefer `TLSConfig` when configuring TLS; when set, it is cloned and takes precedence over other TLS fields.
```go
app := fiber.New()
app.Listen(":443", fiber.ListenConfig{
TLSConfig: &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return myProvider.Certificate(info.ServerName)
},
},
})
```
- Expanded `ListenData` with versioning, handler, process, and PID metadata, plus dedicated startup message hooks for customization. Check out the [Hooks](./api/hooks#startup-message-customization) documentation for further details.
```go title="Customize the startup message"
package main
import (
"fmt"
"os"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Hooks().OnPreStartupMessage(func(sm *fiber.PreStartupMessageData) error {
sm.BannerHeader = "FOOBER " + sm.Version + "\n-------"
// Optional: you can also remove old entries
// sm.ResetEntries()
sm.AddInfo("git-hash", "Git hash", os.Getenv("GIT_HASH"))
sm.AddInfo("prefork", "Prefork", fmt.Sprintf("%v", sm.Prefork), 15)
return nil
})
app.Hooks().OnPostStartupMessage(func(sm *fiber.PostStartupMessageData) error {
if !sm.Disabled && !sm.IsChild && !sm.Prevented {
fmt.Println("startup completed")
}
return nil
})
app.Listen(":5000")
}
```
## 🗺 Router
We have slightly adapted our router interface.
### Handler compatibility
Fiber now ships with a routing adapter (see `adapter.go`) that understands native Fiber handlers alongside `net/http` and `fasthttp` handlers. Route registration helpers accept a required `handler` argument plus optional additional `handlers`, all typed as `any`, and the adapter transparently converts supported handler styles so you can keep using the ecosystem functions you're familiar with.
To align even closer with Express, you can also register handlers that accept the new `fiber.Req` and `fiber.Res` helper interfaces. The adapter understands both two-argument (`func(fiber.Req, fiber.Res)`) and three-argument (`func(fiber.Req, fiber.Res, func() error)`) callbacks, regardless of whether they return an `error`. It also accepts Express-style `next` callbacks that take an `error` (`func(error)` or `func(error) error`). When you include the optional `next` callback, Fiber wires it to `c.Next()` for you so middleware continues to behave as expected. Calling `next(nil)` continues the chain, while passing a non-nil error short-circuits and returns that error. If your handler returns an `error`, the value returned from the injected `next()` bubbles straight back to the caller. When your handler omits an `error` return, Fiber records the result of `next()` and returns it after your function exits so downstream failures still propagate.
| Case | Handler signature | Notes |
| ---- | ----------------- | ----- |
| 1 | `fiber.Handler` | Native Fiber handler. |
| 2 | `func(fiber.Ctx)` | Fiber handler without an error return. |
| 3 | `func(fiber.Req, fiber.Res) error` | Express-style request handler with error return. |
| 4 | `func(fiber.Req, fiber.Res)` | Express-style request handler without error return. |
| 5 | `func(fiber.Req, fiber.Res, func() error) error` | Express-style middleware with an error-returning `next` callback and handler error return. |
| 6 | `func(fiber.Req, fiber.Res, func() error)` | Express-style middleware with an error-returning `next` callback. |
| 7 | `func(fiber.Req, fiber.Res, func()) error` | Express-style middleware with a no-argument `next` callback and handler error return. |
| 8 | `func(fiber.Req, fiber.Res, func())` | Express-style middleware with a no-argument `next` callback. |
| 9 | `func(fiber.Req, fiber.Res, func(error))` | Express-style middleware with an error-accepting `next` callback. |
| 10 | `func(fiber.Req, fiber.Res, func(error)) error` | Express-style middleware with an error-accepting `next` callback and handler error return. |
| 11 | `func(fiber.Req, fiber.Res, func(error) error)` | Express-style middleware with an error-accepting `next` callback that returns an error. |
| 12 | `func(fiber.Req, fiber.Res, func(error) error) error` | Express-style middleware with an error-accepting `next` callback that returns an error and handler error return. |
| 13 | `http.HandlerFunc` | Standard-library handler function adapted through `fasthttpadaptor`. |
| 14 | `http.Handler` | Standard-library handler implementation; pointer receivers must be non-nil. |
| 15 | `func(http.ResponseWriter, *http.Request)` | Standard-library function handlers via `fasthttpadaptor`. |
| 16 | `fasthttp.RequestHandler` | Direct fasthttp handler without error return. |
| 17 | `func(*fasthttp.RequestCtx) error` | fasthttp handler that returns an error to Fiber. |
### Route chaining
`RouteChain` is a new helper inspired by [`Express`](https://expressjs.com/en/api.html#app.route) that makes it easy to declare a stack of handlers on the same path, while the existing `Route` helper stays available for prefix encapsulation.
```go
RouteChain(path string) Register
```
Example
```go
app.RouteChain("/api").RouteChain("/user/:id?").
Get(func(c fiber.Ctx) error {
// Get user
return c.JSON(fiber.Map{"message": "Get user", "id": c.Params("id")})
}).
Post(func(c fiber.Ctx) error {
// Create user
return c.JSON(fiber.Map{"message": "User created"})
}).
Put(func(c fiber.Ctx) error {
// Update user
return c.JSON(fiber.Map{"message": "User updated", "id": c.Params("id")})
}).
Delete(func(c fiber.Ctx) error {
// Delete user
return c.JSON(fiber.Map{"message": "User deleted", "id": c.Params("id")})
})
```
You can find more information about `app.RouteChain` and `app.Route` in the API documentation ([RouteChain](./api/app#routechain), [Route](./api/app#route)).
### Automatic HEAD routes for GET
Fiber now auto-registers a `HEAD` route whenever you add a `GET` route. The generated handler chain matches the `GET` chain so status codes and headers stay in sync while the response body remains empty, ensuring `HEAD` clients observe the same metadata as a `GET` consumer.
```go title="GET now enables HEAD automatically"
app := fiber.New()
app.Get("/health", func(c fiber.Ctx) error {
c.Set("X-Service", "api")
return c.SendString("OK")
})
// HEAD /health reuses the GET middleware chain and returns headers only.
```
You can still register explicit `HEAD` handlers for any `GET` route, and they continue to win when you add them:
```go title="Override the generated HEAD handler"
app.Head("/health", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent)
})
```
Prefer to manage `HEAD` routes yourself? Disable the feature through `fiber.Config.DisableHeadAutoRegister`:
```go title="Disable automatic HEAD registration"
handler := func(c fiber.Ctx) error {
c.Set("X-Service", "api")
return c.SendString("OK")
}
app := fiber.New(fiber.Config{DisableHeadAutoRegister: true})
app.Get("/health", handler) // HEAD /health now returns 405 unless you add it manually.
```
Auto-generated `HEAD` routes appear in tooling such as `app.Stack()` and cover the same routing scenarios as their `GET` counterparts, including groups, mounted apps, dynamic parameters, and static file handlers.
### Middleware registration
We have aligned our method for middlewares closer to [`Express`](https://expressjs.com/en/api.html#app.use) and now also support the [`Use`](./api/app#use) of multiple prefixes.
Prefix matching is now stricter: partial matches must end at a slash boundary (or be an exact match). This keeps `/api` middleware from running on `/apiv1` while still allowing `/api/:version` style patterns that leverage route parameters, optional segments, or wildcards.
Registering a subapp is now also possible via the [`Use`](./api/app#use) method instead of the old `app.Mount` method.
Example
```go
// register multiple prefixes
app.Use([]string{"/v1", "/v2"}, func(c fiber.Ctx) error {
// Middleware for /v1 and /v2
return c.Next()
})
// define subapp
api := fiber.New()
api.Get("/user", func(c fiber.Ctx) error {
return c.SendString("User")
})
// register subapp
app.Use("/api", api)
```
To enable the routing changes above we had to slightly adjust the signature of the `Add` method.
```diff
- Add(method, path string, handlers ...Handler) Router
+ Add(methods []string, path string, handler any, handlers ...any) Router
```
### Test Config
The `app.Test()` method now allows users to customize their test configurations:
Example
```go
// Create a test app with a handler to test
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("hello world")
})
// Define the HTTP request and custom TestConfig to test the handler
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
testConfig := fiber.TestConfig{
Timeout: 0,
FailOnTimeout: false,
}
// Test the handler using the request and testConfig
resp, err := app.Test(req, testConfig)
```
To provide configurable testing capabilities, we had to change
the signature of the `Test` method.
```diff
- Test(req *http.Request, timeout ...time.Duration) (*http.Response, error)
+ Test(req *http.Request, config ...fiber.TestConfig) (*http.Response, error)
```
The `TestConfig` struct provides the following configuration options:
- `Timeout`: The duration to wait before timing out the test. Use 0 for no timeout.
- `FailOnTimeout`: Controls the behavior when a timeout occurs:
- When true, the test will return an `os.ErrDeadlineExceeded` if the test exceeds the `Timeout` duration.
- When false, the test will return the partial response received before timing out.
If a custom `TestConfig` isn't provided, then the following will be used:
```go
testConfig := fiber.TestConfig{
Timeout: time.Second,
FailOnTimeout: true,
}
```
**Note:** Using this default is **NOT** the same as providing an empty `TestConfig` as an argument to `app.Test()`.
An empty `TestConfig` is the equivalent of:
```go
testConfig := fiber.TestConfig{
Timeout: 0,
FailOnTimeout: false,
}
```
## 🧠 Context
### New Features
- Cookie now allows Partitioned cookies for [CHIPS](https://developers.google.com/privacy-sandbox/3pcd/chips) support. CHIPS (Cookies Having Independent Partitioned State) is a feature that improves privacy by allowing cookies to be partitioned by top-level site, mitigating cross-site tracking.
- Cookie automatic security enforcement: When setting a cookie with `SameSite=None`, Fiber automatically sets `Secure=true` as required by RFC 6265bis and modern browsers (Chrome, Firefox, Safari). This ensures compliance with the "None" SameSite policy. See [Mozilla docs](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#none) and [Chrome docs](https://developers.google.com/search/blog/2020/01/get-ready-for-new-samesitenone-secure) for details.
- `Ctx` now implements the [context.Context](https://pkg.go.dev/context#Context) interface, replacing the former `UserContext` helpers.
### New Methods
- **AutoFormat**: Similar to Express.js, automatically formats the response based on the request's `Accept` header.
- **Deadline**: For implementing `context.Context`.
- **Done**: For implementing `context.Context`.
- **Err**: For implementing `context.Context`.
- **Host**: Similar to Express.js, returns the host name of the request.
- **Port**: Similar to Express.js, returns the port number of the request.
- **IsProxyTrusted**: Checks the trustworthiness of the remote IP.
- **Reset**: Resets context fields for server handlers.
- **Scheme**: Similar to Express.js, returns the scheme (`http` or `https`) of the request.
- **SendEarlyHints**: Sends `HTTP 103 Early Hints` status code with `Link` headers so browsers can preload resources while the final response is being prepared.
- **SendStream**: Similar to Express.js, sends a stream as the response.
- **SendStreamWriter**: Sends a stream using a writer function.
- **SendString**: Similar to Express.js, sends a string as the response.
- **String**: Similar to Express.js, converts a value to a string.
- **Value**: For implementing `context.Context`. Returns request-scoped value from Locals.
- **Context()**: Returns a `context.Context` that can be used outside the handler.
- **SetContext**: Sets the base `context.Context` returned by `Context()` for propagating deadlines or values.
- **ViewBind**: Binds data to a view, replacing the old `Bind` method.
- **CBOR**: Introducing [CBOR](https://cbor.io/) binary encoding format for both request & response body. CBOR is a binary data serialization format which is both compact and efficient, making it ideal for use in web applications.
- **MsgPack**: Introducing [MsgPack](https://msgpack.org/) binary encoding format for both request & response body. MsgPack is a binary serialization format that is more efficient than JSON, making it ideal for high-performance applications.
- **Drop**: Terminates the client connection silently without sending any HTTP headers or response body. This can be used for scenarios where you want to block certain requests without notifying the client, such as mitigating DDoS attacks or protecting sensitive endpoints from unauthorized access.
- **End**: Similar to Express.js, immediately flushes the current response and closes the underlying connection.
- **AcceptsLanguagesExtended**: Matches language ranges using RFC 4647 Extended Filtering with wildcard subtags.
- **FullURL**: Returns the full request URL (scheme + host + original URL).
- **RequestID**: Returns the request identifier from the response or request headers.
- **UserAgent**: Returns the `User-Agent` request header.
- **Referer**: Returns the `Referer` request header.
- **AcceptLanguage**: Returns the `Accept-Language` request header.
- **AcceptEncoding**: Returns the `Accept-Encoding` request header.
- **HasHeader**: Reports whether the request includes a header with the given key.
- **MediaType**: Returns the MIME type from the `Content-Type` header without parameters.
- **Charset**: Returns the `charset` parameter from the `Content-Type` header.
- **IsJSON**: Reports whether the `Content-Type` header is JSON.
- **IsForm**: Reports whether the `Content-Type` header is form-encoded.
- **IsMultipart**: Reports whether the `Content-Type` header is multipart form data.
- **AcceptsJSON**: Reports whether the `Accept` header allows JSON.
- **AcceptsHTML**: Reports whether the `Accept` header allows HTML.
- **AcceptsXML**: Reports whether the `Accept` header allows XML.
- **AcceptsEventStream**: Reports whether the `Accept` header allows `text/event-stream`.
- **Matched**: Detects when the current request path matched a registered route.
- **IsMiddleware**: Indicates if the current handler was registered as middleware.
- **HasBody**: Quickly checks whether the request includes a body.
- **OverrideParam**: Overwrites the value of an existing route parameter, or does nothing if the parameter does not exist.
- **IsWebSocket**: Reports if the request attempts a WebSocket upgrade.
- **IsPreflight**: Identifies CORS preflight requests before handlers run.
### Removed Methods
- **AllParams**: Use `c.Bind().URI()` instead.
- **ParamsInt**: Use `Params` with generic types.
- **QueryBool**: Use `Query` with generic types.
- **QueryFloat**: Use `Query` with generic types.
- **QueryInt**: Use `Query` with generic types.
- **BodyParser**: Use `c.Bind().Body()` instead.
- **CookieParser**: Use `c.Bind().Cookie()` instead.
- **ParamsParser**: Use `c.Bind().URI()` instead.
- **RedirectToRoute**: Use `c.Redirect().Route()` instead.
- **RedirectBack**: Use `c.Redirect().Back()` instead.
- **ReqHeaderParser**: Use `c.Bind().Header()` instead.
- **UserContext**: Removed. `Ctx` itself now satisfies `context.Context`; pass `c` directly where a `context.Context` is required.
- **SetUserContext**: Removed. Use `SetContext` and `Context()` or `context.WithValue` on `c` to store additional request-scoped values.
### Changed Methods
- **Bind**: Now used for binding instead of view binding. Use `c.ViewBind()` for view binding.
- **Format**: Parameter changed from `body interface{}` to `handlers ...ResFmt`.
- **Redirect**: Use `c.Redirect().To()` instead.
- **SendFile**: Now supports different configurations using a config parameter.
- **Attachment and Download**: Non-ASCII filenames now use `filename*` as
specified by [RFC 6266](https://www.rfc-editor.org/rfc/rfc6266) and
[RFC 8187](https://www.rfc-editor.org/rfc/rfc8187).
- **Context()**: The v2 `Context()` method, which returned the underlying `*fasthttp.RequestCtx`, has been renamed to `RequestCtx()`. `Context()` now returns a `context.Context` instead.
### SendEarlyHints
`SendEarlyHints` sends an informational [`103 Early Hints`](https://developer.chrome.com/docs/web-platform/early-hints) response with `Link` headers based on the provided `hints` argument. This allows a browser to start preloading assets while the server is still preparing the final response.
```go
hints := []string{"<https://cdn.com/app.js>; rel=preload; as=script"}
app.Get("/early", func(c fiber.Ctx) error {
if err := c.SendEarlyHints(hints); err != nil {
return err
}
return c.SendString("done")
})
```
Older HTTP/1.1 clients may ignore these interim responses or handle them inconsistently.
### SendStreamWriter
In v3, we introduced support for buffered streaming with the addition of the `SendStreamWriter` method:
```go
func (c Ctx) SendStreamWriter(streamWriter func(w *bufio.Writer)) error
```
With this new method, you can implement:
- Server-Sent Events (SSE)
- Large file downloads
- Live data streaming
```go
app.Get("/sse", func(c fiber.Ctx) error {
c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
return c.SendStreamWriter(func(w *bufio.Writer) {
for {
fmt.Fprintf(w, "event: my-event\n")
fmt.Fprintf(w, "data: Hello SSE\n\n")
if err := w.Flush(); err != nil {
log.Print("Client disconnected!")
return
}
}
})
})
```
You can find more details about this feature in [/docs/api/ctx.md](./api/ctx.md).
### Drop
In v3, we introduced support to silently terminate requests through `Drop`.
```go
func (c Ctx) Drop() error
```
With this method, you can:
- Block certain requests without notifying the client to mitigate DDoS attacks
- Protect sensitive endpoints from unauthorized access without leaking errors.
:::caution
While this feature adds the ability to drop connections, it is still **highly recommended** to use additional
measures (such as **firewalls**, **proxies**, etc.) to further protect your server endpoints by blocking
malicious connections before the server establishes a connection.
:::
```go
app.Get("/", func(c fiber.Ctx) error {
if c.IP() == "192.168.1.1" {
return c.Drop()
}
return c.SendString("Hello World!")
})
```
You can find more details about this feature in [/docs/api/ctx.md](./api/ctx.md).
### End
In v3, we introduced a new method to match the Express.js API's `res.end()` method.
```go
func (c Ctx) End() error
```
With this method, you can:
- Prevent middleware that wraps a handler (code that runs after `c.Next()` returns) from changing the response,
by immediately flushing the current response and closing the connection.
- Use `return c.End()` as an alternative to `return nil`
```go
app.Use(func (c fiber.Ctx) error {
err := c.Next()
if err != nil {
log.Printf("Got error: %v", err)
return c.SendString(err.Error()) // Will be unsuccessful since the response ended below
}
return nil
})
app.Get("/hello", func (c fiber.Ctx) error {
query := c.Query("name", "")
if query == "" {
_ = c.SendString("You don't have a name?")
_ = c.End() // Closes the underlying connection; errors intentionally ignored
return errors.New("No name provided")
}
return c.SendString("Hello, " + query + "!")
})
```
---
## 📎 Binding
Fiber v3 introduces a new binding mechanism that simplifies the process of binding request data to structs. The new binding system supports binding from various sources such as URL parameters, query parameters, headers, and request bodies. This unified approach makes it easier to handle different types of request data in a consistent manner.
### New Features
- Unified binding from URL parameters, query parameters, headers, and request bodies.
- Support for custom binders and constraints.
- Improved error handling and validation.
- Support multipart file binding for `*multipart.FileHeader`, `*[]*multipart.FileHeader`, and `[]*multipart.FileHeader` field types.
- Support for unified binding (`Bind().All()`) with defined precedence order: (URI -> Body -> Query -> Headers -> Cookies). [Learn more](./api/bind.md#all).
- Support MsgPack binding for request body.
Example
```go
type User struct {
ID int `uri:"id"`
Name string `json:"name"`
Email string `json:"email"`
}
app.Post("/user/:id", func(c fiber.Ctx) error {
var user User
if err := c.Bind().Body(&user); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(user)
})
```
In this example, `c.Bind()` returns a binder, and its `Body` method decodes the request body into the `User` struct.
## 🔬 Extractors Package
Fiber v3 introduces a new shared `extractors` package that consolidates value extraction utilities previously duplicated across middleware packages. This package provides a unified API for extracting values from headers, cookies, query parameters, form data, and URL parameters with built-in chain/fallback logic and security considerations.
### Key Features
- **Unified API**: Single package for extracting values from headers, cookies, query parameters, form data, and URL parameters
- **Chain Logic**: Built-in fallback mechanism to try multiple extraction sources in order
- **Source Awareness**: Source inspection capabilities for security-sensitive operations
- **Type Safety**: Strongly typed extraction with proper error handling
- **Performance**: Optimized extraction functions with minimal overhead
### Available Extractors
- `FromAuthHeader(authScheme string)`: Extract from Authorization header with scheme support
- `FromCookie(key string)`: Extract from HTTP cookies
- `FromParam(param string)`: Extract from URL path parameters
- `FromForm(param string)`: Extract from form data
- `FromHeader(header string)`: Extract from custom HTTP headers
- `FromQuery(param string)`: Extract from URL query parameters
- `FromCustom(key string, extractor func(c fiber.Ctx) (string, error))`: Define custom extraction logic with metadata
- `Chain(extractors ...Extractor)`: Chain multiple extractors with fallback logic
### Usage Example
```go
import "github.com/gofiber/fiber/v3/extractors"
// Extract API key from multiple sources with fallback
apiKeyExtractor := extractors.Chain(
extractors.FromHeader("X-API-Key"),
extractors.FromQuery("api_key"),
extractors.FromCookie("api_key"),
)
app.Use(func(c fiber.Ctx) error {
apiKey, err := apiKeyExtractor.Extract(c)
if err != nil {
return c.Status(401).SendString("API key required")
}
// Use apiKey for authentication
return c.Next()
})
```
### Migration from Middleware-Specific Extractors
Middleware packages in Fiber v3 now use the shared extractors package instead of maintaining their own extraction logic. This provides:
- **Code Deduplication**: Eliminates ~500+ lines of duplicated extraction code
- **Consistency**: Standardized extraction behavior across all middleware
- **Maintainability**: Single source of truth for extraction logic
- **Security**: Unified security considerations and warnings
## 🔄 Redirect
Fiber v3 enhances the redirect functionality by introducing new methods and improving existing ones. The new redirect methods provide more flexibility and control over the redirection process.
### New Methods
- `Redirect().To()`: Redirects to a specific URL.
- `Redirect().Route()`: Redirects to a named route.
- `Redirect().Back()`: Redirects to the previous URL.
Example
```go
app.Get("/old", func(c fiber.Ctx) error {
return c.Redirect().To("/new")
})
app.Get("/new", func(c fiber.Ctx) error {
return c.SendString("Welcome to the new route!")
})
```
### Changed behavior
:::info
The default redirect status code has been updated from `302 Found` to `303 See Other` to ensure more consistent behavior across different browsers.
:::
## 🌎 Client package
The Gofiber client has been completely rebuilt. It includes numerous new features such as Cookiejar, request/response hooks, and more.
Take a look at the [client docs](./client/rest.md) to see what's new with the client.
### Configuration improvements
The v3 client centralizes common configuration on the client instance and lets you override it per request with `client.Config`.
You can define base URLs, defaults (headers, cookies, path parameters, timeouts), and toggle path normalization once, while still
using axios-style helpers for each call.
```go
cc := client.New().
SetBaseURL("https://api.service.local").
AddHeader("Authorization", "Bearer <token>").
SetTimeout(5 * time.Second).
SetPathParam("tenant", "acme")
resp, err := cc.Get("/users/:tenant/:id", client.Config{
PathParam: map[string]string{"id": "42"},
Param: map[string]string{"include": "profile"},
DisablePathNormalizing: true,
})
if err != nil {
panic(err)
}
defer resp.Close()
fmt.Println(resp.StatusCode(), resp.String())
```
### Fasthttp transport integration
- `client.NewWithHostClient` and `client.NewWithLBClient` allow you to plug existing `fasthttp` clients directly into Fiber while keeping retries, redirects, and hook logic consistent.
- Dialer, TLS, and proxy helpers now update every host client inside a load balancer, so complex pools inherit the same configuration.
- The Fiber client exposes `Do`, `DoTimeout`, `DoDeadline`, and `CloseIdleConnections`, matching the surface area of the wrapped fasthttp transports.
## 🧰 Generic functions
Fiber v3 introduces new generic functions that provide additional utility and flexibility for developers. These functions are designed to simplify common tasks and improve code readability.
### New Generic Functions
- **StoreInContext**: Stores request-scoped values in both `c.Locals()` and the request `context.Context`, so the same value can be read through middleware `FromContext` helpers and direct locals access.
- **Convert**: Converts a value with a specified converter function and default value.
- **Locals**: Retrieves or sets local values within a request context.
- **Params**: Retrieves route parameters and can handle various types of route parameters.
- **Query**: Retrieves the value of a query parameter from the request URI and can handle various types of query parameters.
- **GetReqHeader**: Returns the HTTP request header specified by the field and can handle various types of header values.
`fiber.Config.PassLocalsToContext` is now available to control whether `StoreInContext` also synchronizes values with request `context.Context` for Fiber-backed contexts. The default is `false` for backward compatibility. `ValueFromContext` continues reading Fiber-backed values from `c.Locals()`.
### Example
Convert
```go
package main
import (
"strconv"
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/convert", func(c fiber.Ctx) error {
value, err := fiber.Convert[int](c.Query("value"), strconv.Atoi, 0)
if err != nil {
return c.Status(fiber.StatusBadRequest).SendString(err.Error())
}
return c.JSON(value)
})
app.Listen(":3000")
}
```
```sh
curl "http://localhost:3000/convert?value=123"
# Output: 123
curl "http://localhost:3000/convert?value=abc"
# Output: "failed to convert: strconv.Atoi: parsing \"abc\": invalid syntax"
```
Locals
```go
package main
import (
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Use("/user/:id", func(c fiber.Ctx) error {
// ask database for user
// ...
// set local values from database
fiber.Locals[string](c, "user", "john")
fiber.Locals[int](c, "age", 25)
// ...
return c.Next()
})
app.Get("/user/*", func(c fiber.Ctx) error {
// get local values
name := fiber.Locals[string](c, "user")
age := fiber.Locals[int](c, "age")
// ...
return c.JSON(fiber.Map{"name": name, "age": age})
})
app.Listen(":3000")
}
```
```sh
curl "http://localhost:3000/user/5"
# Output: {"name":"john","age":25}
```
Params
```go
package main
import (
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/params/:id", func(c fiber.Ctx) error {
id := fiber.Params[int](c, "id", 0)
return c.JSON(id)
})
app.Listen(":3000")
}
```
```sh
curl "http://localhost:3000/params/123"
# Output: 123
curl "http://localhost:3000/params/abc"
# Output: 0
```
Query
```go
package main
import (
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/query", func(c fiber.Ctx) error {
age := fiber.Query[int](c, "age", 0)
return c.JSON(age)
})
app.Listen(":3000")
}
```
```sh
curl "http://localhost:3000/query?age=25"
# Output: 25
curl "http://localhost:3000/query?age=abc"
# Output: 0
```
GetReqHeader
```go
package main
import (
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
app.Get("/header", func(c fiber.Ctx) error {
userAgent := fiber.GetReqHeader[string](c, "User-Agent", "Unknown")
return c.JSON(userAgent)
})
app.Listen(":3000")
}
```
```sh
curl -H "User-Agent: CustomAgent" "http://localhost:3000/header"
# Output: "CustomAgent"
curl "http://localhost:3000/header"
# Output: "Unknown"
```
## 🛠️ Utils {#utils}
Fiber v3 removes the built-in `utils` directory and now imports utility helpers from the separate [`github.com/gofiber/utils/v2`](https://github.com/gofiber/utils) module. See the [migration guide](#utils-migration) for detailed replacement steps and examples.
The `github.com/gofiber/utils/v2` module also introduces new helpers like `ParseInt`, `ParseUint`, `Walk`, `ReadFile`, and `Timestamp`.
## 🧩 Services
Fiber v3 introduces a new feature called Services. This feature allows developers to quickly start services that the application depends on, removing the need to manually provision things like database servers, caches, or message brokers, to name a few.
### Example
Adding a service
```go
package main
import (
"context"
"github.com/gofiber/fiber/v3"
)
type myService struct {
img string
// ...
}
// Start initializes and starts the service. It implements the [fiber.Service] interface.
func (s *myService) Start(ctx context.Context) error {
// start the service
return nil
}
// String returns a string representation of the service.
// It is used to print a human-readable name of the service in the startup message.
// It implements the [fiber.Service] interface.
func (s *myService) String() string {
return s.img
}
// State returns the current state of the service.
// It implements the [fiber.Service] interface.
func (s *myService) State(ctx context.Context) (string, error) {
return "running", nil
}
// Terminate stops and removes the service. It implements the [fiber.Service] interface.
func (s *myService) Terminate(ctx context.Context) error {
// stop the service
return nil
}
func main() {
cfg := &fiber.Config{}
cfg.Services = append(cfg.Services, &myService{img: "postgres:latest"})
cfg.Services = append(cfg.Services, &myService{img: "redis:latest"})
app := fiber.New(*cfg)
// ...
}
```
Output
```sh
$ go run . -v
_______ __
/ ____(_) /_ ___ _____
/ /_ / / __ \/ _ \/ ___/
/ __/ / / /_/ / __/ /
/_/ /_/_.___/\___/_/ v3.0.0
--------------------------------------------------
INFO Server started on: http://127.0.0.1:3000 (bound on host 0.0.0.0 and port 3000)
INFO Services: 2
INFO 🧩 [ RUNNING ] postgres:latest
INFO 🧩 [ RUNNING ] redis:latest
INFO Total handlers count: 2
INFO Prefork: Disabled
INFO PID: 12279
INFO Total process count: 1
```
## 📃 Log
`fiber.AllLogger[T]` interface now has a new generic type parameter `T` and a method called `Logger`. This method can be used to get the underlying logger instance from the Fiber logger middleware. This is useful when you want to configure the logger middleware with a custom logger and still want to access the underlying logger instance with the appropriate type.
You can find more details about this feature in [/docs/api/log.md](./api/log.md#logger).
`logger.Config` now supports a new field called `ForceColors`. This field allows you to force the logger to always use colors, even if the output is not a terminal. This is useful when you want to ensure that the logs are always colored, regardless of the output destination.
```go
package main
import "github.com/gofiber/fiber/v3/middleware/logger"
app.Use(logger.New(logger.Config{
ForceColors: true,
}))
```
## 📦 Storage Interface
The storage interface has been updated to include a new subset of methods with the `WithContext` suffix. These methods allow you to pass a context to the storage operations, enabling better control over timeouts and cancellation if needed. This is particularly useful when storage implementations are used outside of the Fiber core, such as in background jobs or long-running tasks.
**New Method Signatures:**
```go
// GetWithContext gets the value for the given key with a context.
// `nil, nil` is returned when the key does not exist
GetWithContext(ctx context.Context, key string) ([]byte, error)
// SetWithContext stores the given value for the given key
// with an expiration value, 0 means no expiration.
// Empty key or value will be ignored without an error.
SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error
// DeleteWithContext deletes the value for the given key with a context.
// It returns no error if the storage does not contain the key,
DeleteWithContext(ctx context.Context, key string) error
// ResetWithContext resets the storage and deletes all keys with a context.
ResetWithContext(ctx context.Context) error
```
## 🧬 Middlewares
### Important Change for Accessing Middleware Data
In Fiber v3, many middlewares that previously set values in `c.Locals()` using string keys (e.g., `c.Locals("requestid")`) have been updated. To align with Go's context best practices and prevent key collisions, these middlewares now store their specific data in the request's context using unexported keys of custom types.
This means that directly accessing these values via `c.Locals("some_string_key")` will no longer work for such middleware-provided data.
**How to Access Middleware Data in v3:**
Each affected middleware now provides dedicated exported functions to retrieve its specific data from the context. You should use these functions instead of relying on string-based lookups in `c.Locals()`.
Examples include:
- `requestid.FromContext(c)`
- `csrf.TokenFromContext(c)`
- `csrf.HandlerFromContext(c)`
- `session.FromContext(c)`
- `basicauth.UsernameFromContext(c)`
- `keyauth.TokenFromContext(c)`
When used with the Logger middleware, the recommended approach is to use the `CustomTags` feature of the logger, which allows you to call these specific `FromContext` functions. See the [Logger](#logger) section for more details.
### Adaptor
The adaptor middleware has been significantly optimized for performance and efficiency. Key improvements include reduced response times, lower memory usage, and fewer memory allocations. These changes make the middleware more reliable and capable of handling higher loads effectively. Enhancements include the introduction of a `sync.Pool` for managing `fasthttp.RequestCtx` instances and better HTTP request and response handling between net/http and fasthttp contexts.
Incoming body sizes now respect the Fiber app's configured `BodyLimit` (falling back to the default when unset) when running Fiber from `net/http` through the adaptor, returning `413 Request Entity Too Large` for oversized payloads.
| Payload Size | Metric | V2 | V3 | Percent Change |
| ------------ | -------------- | ------------ | ----------- | -------------- |
| 100KB | Execution Time | 1056 ns/op | 588.6 ns/op | -44.25% |
| | Memory Usage | 2644 B/op | 254 B/op | -90.39% |
| | Allocations | 16 allocs/op | 5 allocs/op | -68.75% |
| 500KB | Execution Time | 1061 ns/op | 562.9 ns/op | -46.94% |
| | Memory Usage | 2644 B/op | 248 B/op | -90.62% |
| | Allocations | 16 allocs/op | 5 allocs/op | -68.75% |
| 1MB | Execution Time | 1080 ns/op | 629.7 ns/op | -41.68% |
| | Memory Usage | 2646 B/op | 267 B/op | -89.91% |
| | Allocations | 16 allocs/op | 5 allocs/op | -68.75% |
| 5MB | Execution Time | 1093 ns/op | 540.3 ns/op | -50.58% |
| | Memory Usage | 2654 B/op | 254 B/op | -90.43% |
| | Allocations | 16 allocs/op | 5 allocs/op | -68.75% |
| 10MB | Execution Time | 1044 ns/op | 533.1 ns/op | -48.94% |
| | Memory Usage | 2665 B/op | 258 B/op | -90.32% |
| | Allocations | 16 allocs/op | 5 allocs/op | -68.75% |
| 25MB | Execution Time | 1069 ns/op | 540.7 ns/op | -49.42% |
| | Memory Usage | 2706 B/op | 289 B/op | -89.32% |
| | Allocations | 16 allocs/op | 5 allocs/op | -68.75% |
| 50MB | Execution Time | 1137 ns/op | 554.6 ns/op | -51.21% |
| | Memory Usage | 2734 B/op | 298 B/op | -89.10% |
| | Allocations | 16 allocs/op | 5 allocs/op | -68.75% |
### BasicAuth
The BasicAuth middleware now validates the `Authorization` header more rigorously and sets security-focused response headers. Passwords must be provided in **hashed** form (e.g. SHA-256 or bcrypt) rather than plaintext. The default challenge includes the `charset="UTF-8"` parameter and disables caching. Responses also set a `Vary: Authorization` header to prevent caching based on credentials. Passwords are no longer stored in the request context. A `Charset` option controls the value used in the challenge header.
A new `HeaderLimit` option restricts the maximum length of the `Authorization` header (default: `8192` bytes).
The `Authorizer` function now receives the current `fiber.Ctx` as a third argument, allowing credential checks to incorporate request context.
### Cache
We are excited to introduce a new option in our caching middleware: Cache Invalidator. This feature provides greater control over cache management, allowing you to define custom conditions for invalidating cache entries.
The middleware now emits `Cache-Control` headers by default (set the new `DisableCacheControl` flag to turn this off), increases the default `Expiration` from `1 minute` to `5 minutes`, and applies a new `MaxBytes` limit of `1 MB` (previously unlimited).
Additionally, the caching middleware has been optimized to avoid caching non-cacheable status codes, as defined by the [HTTP standards](https://datatracker.ietf.org/doc/html/rfc7231#section-6.1). This improvement enhances cache accuracy and reduces unnecessary cache storage usage.
Cached responses now include an RFC-compliant Age header, providing a standardized indication of how long a response has been stored in cache since it was originally generated. This enhancement improves HTTP compliance and facilitates better client-side caching strategies.
Cache keys are now redacted in logs and error messages by default, and a `DisableValueRedaction` boolean (default `false`) lets you opt out when you need the raw value for troubleshooting.
:::note
The deprecated `Store` and `Key` options have been removed in v3. Use `Storage` and `KeyGenerator` instead.
:::
### ResponseTime
A new response time middleware measures how long each request takes to process and adds the duration to the response headers.
By default it writes the elapsed time to `X-Response-Time`, and you can change the header name. A `Next` hook lets you skip
endpoints such as health checks.
### CORS
We've made some changes to the CORS middleware to improve its functionality and flexibility. Here's what's new:
#### New Struct Fields
- `Config.AllowPrivateNetwork`: This new field is a boolean that allows you to control whether private networks are allowed. This is related to the [Private Network Access (PNA)](https://wicg.github.io/private-network-access/) specification from the [Web Incubator Community Group (WICG)](https://wicg.io/). When set to `true`, the CORS middleware will allow CORS preflight requests from private networks and respond with the `Access-Control-Allow-Private-Network: true` header. This could be useful in development environments or specific use cases, but should be done with caution due to potential security risks.
#### Updated Struct Fields
We've updated several fields from a single string (containing comma-separated values) to slices, allowing for more explicit declaration of multiple values. Here are the updated fields:
- `Config.AllowOrigins`: Now accepts a slice of strings, each representing an allowed origin.
- `Config.AllowMethods`: Now accepts a slice of strings, each representing an allowed method.
- `Config.AllowHeaders`: Now accepts a slice of strings, each representing an allowed header.
- `Config.ExposeHeaders`: Now accepts a slice of strings, each representing an exposed header.
Additionally, panic messages and logs redact misconfigured origins by default, and a `DisableValueRedaction` flag (default `false`) lets you reveal them when necessary.
### Compression
- Added support for `zstd` compression alongside `gzip`, `deflate`, and `brotli`.
- Strong `ETag` values are now recomputed for compressed payloads so validators remain accurate.
- Compression is bypassed for responses that already specify `Content-Encoding`, for range requests or `206` statuses, and when either side sends `Cache-Control: no-transform`.
- `HEAD` requests still negotiate compression so `Content-Encoding`, `Content-Length`, `ETag`, and `Vary` match a corresponding `GET`, but the body is omitted.
- `Vary: Accept-Encoding` is merged into responses even when compression is skipped, preventing caches from mixing encoded and unencoded variants.
### CSRF
The `Expiration` field in the CSRF middleware configuration has been renamed to `IdleTimeout` to better describe its functionality. Additionally, the default value has been reduced from 1 hour to 30 minutes.
CSRF now redacts tokens and storage keys by default and exposes a `DisableValueRedaction` toggle (default `false`) if you must surface those values in diagnostics.
The CSRF middleware now validates the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header for unsafe HTTP methods. When present, requests with invalid `Sec-Fetch-Site` values (not one of "same-origin", "none", "same-site", or "cross-site") are rejected with `ErrFetchSiteInvalid`. Valid or absent headers proceed to standard origin and token validation checks, providing an early gate to catch malformed requests while maintaining compatibility with legitimate cross-site traffic.
### Idempotency
Idempotency middleware now redacts keys by default and offers a `DisableValueRedaction` configuration flag (default `false`) to expose them when debugging.
### EncryptCookie
- Added support for specifying key length when using `encryptcookie.GenerateKey(length)`. Keys must be base64-encoded and may be 16, 24, or 32 bytes when decoded, supporting AES-128, AES-192, and AES-256 (default).
- Custom encryptor and decryptor callbacks now receive the cookie name. The default AES-GCM helpers bind it as additional authenticated data (AAD) so ciphertext cannot be replayed under a different cookie.
- **Breaking change:** Custom encryptor/decryptor hooks now accept the cookie name as their first argument. Update overrides like:
```go
// Before
Encryptor func(value, key string) (string, error)
Decryptor func(value, key string) (string, error)
// After
Encryptor func(name, value, key string) (string, error)
Decryptor func(name, value, key string) (string, error)
```
### Favicon
The favicon middleware now caps cached favicon assets with a configurable `MaxBytes` limit (default `1 MiB`) and uses a limited reader to guard against oversized files when loading from disk.
### EnvVar
The `ExcludeVars` field has been removed from the EnvVar middleware configuration. When upgrading, remove any references to this field and explicitly list the variables you wish to expose using `ExportVars`.
### Filesystem
The filesystem middleware was removed to reduce confusion with the static middleware.
The static middleware now covers the functionality of both. Review the [static middleware](./middleware/static.md) docs or the [migration guide](#-migration-guide) for the updated usage.
### Healthcheck
The healthcheck middleware has been simplified into a single generic probe handler. No endpoints are registered automatically. Register the middleware on each route you need—using helpers like `healthcheck.LivenessEndpoint`, `healthcheck.ReadinessEndpoint`, or `healthcheck.StartupEndpoint`—and optionally supply a `Probe` function to determine the service's health. This approach lets you expose any number of health check routes.
Refer to the [healthcheck middleware migration guide](./middleware/healthcheck.md) or the [general migration guide](#-migration-guide) to review the changes.
### KeyAuth
The keyauth middleware was updated to introduce a configurable `Realm` field for the `WWW-Authenticate` header.
The old string-based `KeyLookup` configuration has been replaced with an `Extractor` field. Use helper functions like `keyauth.FromHeader`, `keyauth.FromAuthHeader`, or `keyauth.FromCookie` to define where the key should be retrieved from. Multiple sources can be combined with `keyauth.Chain`. See the migration guide below.
New `Challenge`, `Error`, `ErrorDescription`, `ErrorURI`, and `Scope` fields allow customizing the `WWW-Authenticate` header, returning Bearer error details, and specifying required scopes. `ErrorURI` values are validated as absolute, a default `ApiKey` challenge is emitted when using non-Authorization extractors, Bearer `error` values are validated, credentials must conform to RFC 7235 `token68` syntax, and `scope` values are checked against RFC 6750's `scope-token` format. The header is also emitted only after the status code is finalized.
### Logger
A new helper function, `LoggerToWriter`, has been added to the logger middleware. It lets you use third-party loggers such as `logrus` or `zap` with the Fiber logger middleware without an extra adapter. The example below shows how to use `zap` with the Fiber logger middleware.
Logger configuration now uses `Stream` instead of `Output` for the destination writer, so update your logger middleware configuration when migrating to v3.
Custom logger integrations should update any `LoggerFunc` implementations to the new signature that receives a pointer to the middleware config: `func(c fiber.Ctx, data *logger.Data, cfg *logger.Config) error`.
Example
```go
package main
import (
"github.com/gofiber/contrib/fiberzap/v2"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/fiber/v3/middleware/logger"
)
func main() {
// Create a new Fiber instance
app := fiber.New()
// Create a new zap logger which is compatible with Fiber AllLogger interface
zap := fiberzap.NewLogger(fiberzap.LoggerConfig{
ExtraKeys: []string{"request_id"},
})
// Use the logger middleware with the zap logger
app.Use(logger.New(logger.Config{
Stream: logger.LoggerToWriter(zap, log.LevelDebug),
}))
// Define a route
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
// Start server on http://localhost:3000
app.Listen(":3000")
}
```
:::note
The deprecated `TagHeader` constant was removed. Use `TagReqHeader` when you need to log request headers.
:::
#### Logging Middleware Values (e.g., Request ID)
In Fiber v3, middleware (like `requestid`) now stores values in the request context using unexported keys of custom types. This aligns with Go's context best practices to prevent key collisions between packages.
As a result, directly accessing these values using string keys with `c.Locals("your_key")` or in the logger format string with `${locals:your_key}` (e.g., `${locals:requestid}`) will no longer work for values set by such middleware.
**Recommended Solution: `CustomTags`**
The cleanest and most maintainable way to include these middleware-specific values in your logs is by using the `CustomTags` option in the logger middleware configuration. This allows you to define a custom function to retrieve the value correctly from the context.
Example: Logging Request ID with CustomTags
```go
package main
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/logger"
"github.com/gofiber/fiber/v3/middleware/requestid"
)
func main() {
app := fiber.New()
// Ensure requestid middleware is used before the logger
app.Use(requestid.New())
app.Use(logger.New(logger.Config{
CustomTags: map[string]logger.LogFunc{
"requestid": func(output logger.Buffer, c fiber.Ctx, data *logger.Data, extraParam string) (int, error) {
// Retrieve the request ID using the middleware's specific function
return output.WriteString(requestid.FromContext(c))
},
},
// Use the custom tag in your format string
Format: "[${time}] ${ip} - ${requestid} - ${status} ${method} ${path}\n",
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Listen(":3000")
}
```
**Alternative: Manually Copying to `Locals`**
If you have existing logging patterns that rely on `c.Locals` or prefer to manage these values in `Locals` for other reasons, you can manually copy the value from the context to `c.Locals` in a preceding middleware:
Example: Manually setting requestid in Locals
```go
app.Use(requestid.New()) // Request ID middleware
app.Use(func(c fiber.Ctx) error {
// Manually copy the request ID to Locals
c.Locals("requestid", requestid.FromContext(c))
return c.Next()
})
app.Use(logger.New(logger.Config{
// Now ${locals:requestid} can be used, but CustomTags is generally preferred
Format: "[${time}] ${ip} - ${locals:requestid} - ${status} ${method} ${path}\n",
}))
```
Both approaches ensure your logger can access these values while respecting Go's context practices.
The `Skip` option is a function that determines whether a request's log entry is skipped (when it returns `true`) or written to `Stream`.
Example Usage
```go
app.Use(logger.New(logger.Config{
Skip: func(c fiber.Ctx) bool {
// Skip logging HTTP 200 requests
return c.Response().StatusCode() == fiber.StatusOK
},
}))
```
```go
app.Use(logger.New(logger.Config{
Skip: func(c fiber.Ctx) bool {
// Only log errors, similar to an error.log
return c.Response().StatusCode() < 400
},
}))
```
#### Predefined Formats
Logger provides predefined formats that you can use by name or directly by specifying the format string.
Example Usage
```go
app.Use(logger.New(logger.Config{
Format: logger.FormatCombined,
}))
```
See the [Logger documentation](./middleware/logger.md#predefined-formats) for more details.
### Limiter
The limiter middleware uses a new Fixed Window Rate Limiter implementation.
Custom limiter algorithms should now implement the updated `limiter.Handler` interface, whose `New` method receives a pointer to the active config: `New(cfg *limiter.Config) fiber.Handler`.
Limiter now redacts request keys in error paths by default. A new `DisableValueRedaction` boolean (default `false`) lets you reveal the raw limiter key if diagnostics require it.
:::note
Deprecated fields `Duration`, `Store`, and `Key` have been removed in v3. Use `Expiration`, `Storage`, and `KeyGenerator` instead.
:::
### Monitor
The Monitor middleware has been moved to the [Contrib package](https://github.com/gofiber/contrib/tree/main/monitor) in [PR #1172](https://github.com/gofiber/contrib/pull/1172).
### Proxy
The proxy middleware has been updated to improve consistency with Go naming conventions. The `TlsConfig` field in the configuration struct has been renamed to `TLSConfig`. Additionally, the `WithTlsConfig` method has been removed; you should now configure TLS directly via the `TLSConfig` property within the `Config` struct.
The proxy now drops the `Connection` header by default. Set the new `KeepConnectionHeader` option (default `false`) to `true` to retain it.
`proxy.Balancer` now accepts an optional variadic configuration: call `proxy.Balancer()` to use defaults or continue passing a `proxy.Config` value as before.
### Recover
The Recover middleware allows customizing the error it returns. Set a `PanicHandler` in its `Config` to change the default behavior.
### Session
The Session middleware has undergone key changes in v3 to improve functionality and flexibility. While v2 methods remain available for backward compatibility, we now recommend using the new middleware handler for session management.
#### Key Updates
The session middleware has undergone significant improvements in v3, focusing on type safety, flexibility, and better developer experience.
#### Key Changes
- **Extractor Pattern**: The string-based `KeyLookup` configuration has been replaced with a more flexible and type-safe `Extractor` function pattern.
- **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewStore` or `NewWithStore` for custom store integration.
- **Manual Session Release**: Session instances are no longer automatically released after being saved. To ensure proper lifecycle management, you must manually call `sess.Release()`.
- **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically.
- **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity.
- **Default KeyGenerator**: Changed from `utils.UUIDv4` to `utils.SecureToken`, producing base64-encoded tokens instead of UUID format.
For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide).
### Timeout
The timeout middleware is now configurable. A new `Config` struct allows customizing the timeout duration, defining a handler that runs when a timeout occurs, and specifying errors to treat as timeouts. The `New` function now accepts a `Config` value instead of a duration.
**Behavioral changes:**
- **Immediate return on timeout**: The middleware now returns immediately when a timeout occurs, without waiting for the handler to finish. This is achieved through the new **Abandon mechanism** which marks the context as abandoned so it won't be returned to the pool while the handler is still running.
- **Context propagation**: The timeout context is properly propagated to the handler. Handlers can detect timeouts by listening on `c.Context().Done()` and return early.
- **Panic handling**: Panics in the handler are caught and converted to `500 Internal Server Error` responses.
- **Race-free design**: The implementation uses fasthttp's `TimeoutErrorWithCode` combined with Fiber's Abandon mechanism to ensure complete race-freedom between the middleware, handler goroutine, and context pooling.
**New Ctx methods for the Abandon mechanism:**
- `Abandon()`: Marks the context as abandoned
- `IsAbandoned()`: Returns true if the context was abandoned
- `ForceRelease()`: Releases an abandoned context back to the pool (for advanced use)
**Migration:** Replace calls like `timeout.New(handler, 2*time.Second)` with `timeout.New(handler, timeout.Config{Timeout: 2 * time.Second})`.
## 🔌 Addons
In v3, Fiber introduced Addons. Addons are additional useful packages that can be used in Fiber.
### Retry
The Retry addon is a new addon that implements a retry mechanism for unsuccessful network operations. It uses an exponential backoff algorithm with jitter.
It calls the function repeatedly until it succeeds. If all attempts fail, it returns an error.
A random jitter is added at each retry step to break synchronization between clients and avoid collisions (for example, many clients retrying at the same moment).
Example
```go
package main
import (
"fmt"
"github.com/gofiber/fiber/v3/addon/retry"
"github.com/gofiber/fiber/v3/client"
)
func main() {
expBackoff := retry.NewExponentialBackoff(retry.Config{})
// Create the client once and reuse it across attempts
cc := client.New()
// Local variables that will be used inside of Retry
var resp *client.Response
var err error
// Retry a network request and return an error to signify to try again
err = expBackoff.Retry(func() error {
resp, err = cc.Get("https://gofiber.io")
if err != nil {
return fmt.Errorf("GET gofiber.io failed: %w", err)
}
if resp.StatusCode() != 200 {
return fmt.Errorf("GET gofiber.io returned status %d, want 200", resp.StatusCode())
}
return nil
})
// If all retries failed, panic
if err != nil {
panic(err)
}
fmt.Printf("GET gofiber.io succeeded with status code %d\n", resp.StatusCode())
}
```
## 📋 Migration guide
To streamline upgrades between Fiber versions, the Fiber CLI ships with a
`migrate` command:
```bash
go install github.com/gofiber/cli/fiber@latest
fiber migrate --to v3
```
### Options
- `-t, --to string` migrate to a specific version, e.g. `v3.0.0`
- `-f, --force` force migration even if already on that version
- `-s, --skip_go_mod` skip running `go mod tidy`, `go mod download`, and `go mod vendor`
### Changes Overview
- [🚀 App](#-app-1)
- [🎣 Hooks](#-hooks-1)
- [🚀 Listen](#-listen-1)
- [🗺 Router](#-router-1)
- [🧠 Context](#-context-1)
- [📎 Binding (was Parser)](#-parser)
- [🔄 Redirect](#-redirect-1)
- [🧾 Log](#-log-1)
- [🌎 Client package](#-client-package-1)
- [🛠️ Utils](#utils-migration)
- [🧬 Middlewares](#-middlewares-1)
- [Important Change for Accessing Middleware Data](#important-change-for-accessing-middleware-data)
- [BasicAuth](#basicauth-1)
- [Cache](#cache-1)
- [CORS](#cors-1)
- [CSRF](#csrf-1)
- [Filesystem](#filesystem-1)
- [EnvVar](#envvar-1)
- [Favicon](#favicon)
- [Healthcheck](#healthcheck-1)
- [Monitor](#monitor-1)
- [Proxy](#proxy-1)
- [Session](#session-1)
### 🚀 App
#### Static
Since `app.Static()` has been removed, move those calls to the static middleware as shown below:
```go
// Before
app.Static("/", "./public")
app.Static("/prefix", "./public")
app.Static("/prefix", "./public", fiber.Static{
Index: "index.htm",
})
app.Static("*", "./public/index.html")
```
```go
// After
app.Get("/*", static.New("./public"))
app.Get("/prefix*", static.New("./public"))
app.Get("/prefix*", static.New("./public", static.Config{
IndexNames: []string{"index.htm", "index.html"},
}))
app.Get("*", static.New("./public/index.html"))
```
:::caution
You must append `*` to the route path unless you register the static middleware with `app.Use`.
:::
#### Trusted Proxies
We've renamed `EnableTrustedProxyCheck` to `TrustProxy` and moved `TrustedProxies` to `TrustProxyConfig`.
**Important:** To use proxy headers like `X-Forwarded-For` with `c.IP()`, you must configure **all** of `TrustProxy`, `ProxyHeader`, and a trusted proxy via `TrustProxyConfig`. If the proxy is not trusted (for example, if you set only `ProxyHeader` or only `TrustProxy` without configuring `TrustProxyConfig`), proxy headers are ignored and `c.IP()` will return the remote TCP IP instead.
```go
// Before
app := fiber.New(fiber.Config{
// EnableTrustedProxyCheck enables the trusted proxy check.
EnableTrustedProxyCheck: true,
// TrustedProxies is a list of trusted proxy IP ranges/addresses.
TrustedProxies: []string{"0.8.0.0", "127.0.0.0/8", "::1/128"},
})
```
```go
// After
app := fiber.New(fiber.Config{
// TrustProxy enables the trusted proxy check
TrustProxy: true,
// ProxyHeader specifies which header to read the real client IP from
ProxyHeader: fiber.HeaderXForwardedFor,
// TrustProxyConfig allows for configuring trusted proxies.
TrustProxyConfig: fiber.TrustProxyConfig{
// Proxies is a list of trusted proxy IP ranges/addresses.
Proxies: []string{"0.8.0.0"},
// Trust all loop-back IP addresses (127.0.0.0/8, ::1/128)
Loopback: true,
// Trust Unix domain socket connections
UnixSocket: true,
},
})
```
For detailed proxy configuration guidance, see the [reverse proxy guide](./guide/reverse-proxy.md).
### 🎣 Hooks
`OnShutdown` has been replaced by two hooks: `OnPreShutdown` and `OnPostShutdown`.
Use them to run cleanup code before and after the server shuts down. When handling
shutdown errors, register an `OnPostShutdown` hook and call `app.Listen()` in a goroutine.
```go
// Before
app.Hooks().OnShutdown(func() error {
// Code to run before shutdown
return nil
})
```
```go
// After
app.Hooks().OnPreShutdown(func() error {
// Code to run before shutdown
return nil
})
```
### 🚀 Listen
The `Listen` helpers (`ListenTLS`, `ListenMutualTLS`, etc.) were removed. Use
`app.Listen()` with `fiber.ListenConfig` and a `tls.Config` when TLS is required.
Options such as `ListenerNetwork` and `UnixSocketFileMode` are now configured via
this struct. Prefer `TLSConfig` when you need full control, or use `CertFile` and
`CertKeyFile` for quick TLS setup.
```go
// Before
app.ListenTLS(":3000", "cert.pem", "key.pem")
```
```go
// After
app.Listen(":3000", fiber.ListenConfig{
CertFile: "./cert.pem",
CertKeyFile: "./key.pem",
})
```
### 🗺 Router
#### Direct `net/http` handlers
Route registration helpers now accept native `net/http` handlers. Pass an
`http.Handler`, `http.HandlerFunc`, or compatible function directly to methods
such as `app.Get`, `Group`, or `RouteChain` and Fiber will adapt it at
registration time. Manual wrapping through the adaptor middleware is no longer
required for these common cases.
:::note Compatibility considerations
Adapted handlers stick to `net/http` semantics. They do not interact with `fiber.Ctx`
and are slower than native Fiber handlers because of the extra conversion layer. Use
them to ease migrations, but prefer Fiber handlers in performance-critical paths.
:::
```go
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := w.Write([]byte("served by net/http")); err != nil {
panic(err)
}
})
app.Get("/", httpHandler)
```
#### Middleware Registration
The signatures for [`Add`](#middleware-registration) and [`Route`](#route-chaining) have been changed.
To migrate [`Add`](#middleware-registration), pass the HTTP methods as a slice (`[]string`) instead of a single string.
```go
// Before
app.Add(fiber.MethodPost, "/api", myHandler)
```
```go
// After
app.Add([]string{fiber.MethodPost}, "/api", myHandler)
```
#### Mounting
In this release, the `Mount` method has been removed. Instead, you can use the `Use` method to achieve similar functionality.
```go
// Before
app.Mount("/api", apiApp)
```
```go
// After
app.Use("/api", apiApp)
```
#### Route Chaining
Refer to the [route chaining](#route-chaining) section for details on the new `RouteChain` helper. The `Route` function now matches its v2 behavior for prefix encapsulation.
```go
// Before
app.Route("/api", func(apiGrp Router) {
apiGrp.Route("/user/:id?", func(userGrp Router) {
userGrp.Get("/", func(c fiber.Ctx) error {
// Get user
return c.JSON(fiber.Map{"message": "Get user", "id": c.Params("id")})
})
userGrp.Post("/", func(c fiber.Ctx) error {
// Create user
return c.JSON(fiber.Map{"message": "User created"})
})
})
})
```
```go
// After
app.RouteChain("/api").RouteChain("/user/:id?").
Get(func(c fiber.Ctx) error {
// Get user
return c.JSON(fiber.Map{"message": "Get user", "id": c.Params("id")})
}).
Post(func(c fiber.Ctx) error {
// Create user
return c.JSON(fiber.Map{"message": "User created"})
})
```
### 🗺 RebuildTree
We introduced a new method that enables rebuilding the route tree stack at runtime. This allows you to add routes dynamically while your application is running and update the route tree to make the new routes available for use.
For more details, refer to the [app documentation](./api/app.md#rebuildtree).
#### Example Usage
```go
app.Get("/define", func(c fiber.Ctx) error { // Define a new route dynamically
app.Get("/dynamically-defined", func(c fiber.Ctx) error { // Adding a dynamically defined route
return c.SendStatus(http.StatusOK)
})
app.RebuildTree() // Rebuild the route tree to register the new route
return c.SendStatus(http.StatusOK)
})
```
In this example, a new route is defined, and `RebuildTree()` is called to ensure the new route is registered and available.
Note: Use this method with caution. It is **not** thread-safe and can be very performance-intensive, so use it sparingly and primarily in development mode. It must not be invoked concurrently.
#### RemoveRoute
- **RemoveRoute**: Removes route by path
- **RemoveRouteByName**: Removes route by name
- **RemoveRouteFunc**: Removes route by a function having `*Route` parameter
For more details, refer to the [app documentation](./api/app.md#removeroute).
### 🧠 Context
Fiber v3 introduces several new features and changes to the Ctx interface, enhancing its functionality and flexibility.
- **ParamsInt**: Use `Params` with generic types.
- **QueryBool**: Use `Query` with generic types.
- **QueryFloat**: Use `Query` with generic types.
- **QueryInt**: Use `Query` with generic types.
- **Bind**: Now used for binding instead of view binding. Use `c.ViewBind()` for view binding.
In Fiber v3, the `Ctx` parameter in handlers is now an interface, which means the `*` symbol is no longer used. Here is an example demonstrating this change:
Example
**Before**:
```go
package main
import (
"github.com/gofiber/fiber/v2"
)
func main() {
app := fiber.New()
// Route Handler with *fiber.Ctx
app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Listen(":3000")
}
```
**After**:
```go
package main
import (
"github.com/gofiber/fiber/v3"
)
func main() {
app := fiber.New()
// Route Handler without *fiber.Ctx
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
app.Listen(":3000")
}
```
**Explanation**:
In this example, the `Ctx` parameter in the handler is used as an interface (`fiber.Ctx`) instead of a pointer (`*fiber.Ctx`). This change allows for more flexibility and customization in Fiber v3.
#### 📎 Parser
The `Parser` section in Fiber v3 has undergone significant changes to improve functionality and flexibility.
##### Migration Instructions
1. **BodyParser**: Use `c.Bind().Body()` instead of `c.BodyParser()`.
Example
```go
// Before
app.Post("/user", func(c *fiber.Ctx) error {
var user User
if err := c.BodyParser(&user); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(user)
})
```
```go
// After
app.Post("/user", func(c fiber.Ctx) error {
var user User
if err := c.Bind().Body(&user); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(user)
})
```
2. **ParamsParser**: Use `c.Bind().URI()` instead of `c.ParamsParser()`. Note that the struct tag has changed from `params` to `uri`.
Example
```go
// Before
type Params struct {
ID int `params:"id"`
}
app.Get("/user/:id", func(c *fiber.Ctx) error {
var params Params
if err := c.ParamsParser(&params); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(params)
})
```
```go
// After
type Params struct {
ID int `uri:"id"`
}
app.Get("/user/:id", func(c fiber.Ctx) error {
var params Params
if err := c.Bind().URI(&params); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(params)
})
```
3. **QueryParser**: Use `c.Bind().Query()` instead of `c.QueryParser()`.
Example
```go
// Before
app.Get("/search", func(c *fiber.Ctx) error {
var query Query
if err := c.QueryParser(&query); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(query)
})
```
```go
// After
app.Get("/search", func(c fiber.Ctx) error {
var query Query
if err := c.Bind().Query(&query); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(query)
})
```
4. **CookieParser**: Use `c.Bind().Cookie()` instead of `c.CookieParser()`.
Example
```go
// Before
app.Get("/cookie", func(c *fiber.Ctx) error {
var cookie Cookie
if err := c.CookieParser(&cookie); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(cookie)
})
```
```go
// After
app.Get("/cookie", func(c fiber.Ctx) error {
var cookie Cookie
if err := c.Bind().Cookie(&cookie); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()})
}
return c.JSON(cookie)
})
```
#### 🔄 Redirect
Fiber v3 enhances the redirect functionality by introducing new methods and improving existing ones. The new redirect methods provide more flexibility and control over the redirection process.
##### Migration Instructions
1. **RedirectToRoute**: Use `c.Redirect().Route()` instead of `c.RedirectToRoute()`.
Example
```go
// Before
app.Get("/old", func(c *fiber.Ctx) error {
return c.RedirectToRoute("newRoute")
})
```
```go
// After
app.Get("/old", func(c fiber.Ctx) error {
return c.Redirect().Route("newRoute")
})
```
2. **RedirectBack**: Use `c.Redirect().Back()` instead of `c.RedirectBack()`.
Example
```go
// Before
app.Get("/back", func(c *fiber.Ctx) error {
return c.RedirectBack()
})
```
```go
// After
app.Get("/back", func(c fiber.Ctx) error {
return c.Redirect().Back()
})
```
3. **Redirect**: Use `c.Redirect().To()` instead of `c.Redirect()`.
Example
```go
// Before
app.Get("/old", func(c *fiber.Ctx) error {
return c.Redirect("/new")
})
```
```go
// After
app.Get("/old", func(c fiber.Ctx) error {
return c.Redirect().To("/new")
})
```
#### 🧾 Log
The `ConfigurableLogger` and `AllLogger` interfaces now use generics. You can specify the underlying logger type when implementing these interfaces. While `any` can be used for maximum flexibility in some contexts, when retrieving the concrete logger via `log.DefaultLogger`, you must specify the exact underlying logger type, for example `log.DefaultLogger[*MyLogger]().Logger()`.
### 🌎 Client package
Fiber v3 introduces a completely rebuilt client package with numerous new features such as Cookiejar, request/response hooks, and more. Here is a guide to help you migrate from Fiber v2 to Fiber v3.
#### New Features
- **Cookiejar**: Manage cookies automatically.
- **Request/Response Hooks**: Customize request and response handling.
- **Improved Error Handling**: Better error management and reporting.
#### Migration Instructions
**Import Path**:
Update the import path to the new client package.
Before
```go
import "github.com/gofiber/fiber/v2/client"
```
After
```go
import "github.com/gofiber/fiber/v3/client"
```
**Common migrations**:
1. **Shared defaults instead of per-call mutation**: Move headers and timeouts into the reusable client and override with `client.Config` when needed.
Example
```go
// Before
status, body, errs := fiber.Get("https://api.example.com/users").
Set("Authorization", "Bearer "+token).
Timeout(5 * time.Second).
String()
if len(errs) > 0 {
return fmt.Errorf("request failed: %v", errs)
}
fmt.Println(status, body)
```
```go
// After
cli := client.New().
AddHeader("Authorization", "Bearer "+token).
SetTimeout(5 * time.Second)
resp, err := cli.Get("https://api.example.com/users")
if err != nil {
return err
}
defer resp.Close()
fmt.Println(resp.StatusCode(), resp.String())
```
2. **Body handling**: Replace `Agent.JSON(...).Struct(&dst)` with request bodies through `client.Config` (or `Request.SetJSON`) and decode the response via `Response.JSON`.
Example
```go
// Before
var created user
status, _, errs := fiber.Post("https://api.example.com/users").
JSON(payload).
Struct(&created)
if len(errs) > 0 {
return fmt.Errorf("request failed: %v", errs)
}
fmt.Println(status, created)
```
```go
// After
cli := client.New()
resp, err := cli.Post("https://api.example.com/users", client.Config{
Body: payload,
})
if err != nil {
return err
}
defer resp.Close()
var created user
if err := resp.JSON(&created); err != nil {
return fmt.Errorf("decode failed: %w", err)
}
fmt.Println(resp.StatusCode(), created)
```
3. **Path and query parameters**: Use the new path/query helpers instead of manually formatting URLs.
Example
```go
// Before
code, body, errs := fiber.Get(fmt.Sprintf("https://api.example.com/users/%s", id)).
QueryString("active=true").
String()
if len(errs) > 0 {
return fmt.Errorf("request failed: %v", errs)
}
fmt.Println(code, body)
```
```go
// After
cli := client.New().SetBaseURL("https://api.example.com")
resp, err := cli.Get("/users/:id", client.Config{
PathParam: map[string]string{"id": id},
Param: map[string]string{"active": "true"},
})
if err != nil {
return err
}
defer resp.Close()
fmt.Println(resp.StatusCode(), resp.String())
```
4. **Agent helpers**: `Agent.Bytes`, `AcquireAgent`, and `Agent.Parse` have been removed. Reuse a `client.Client` instance (or pool requests/responses directly) and access response data through the new typed helpers.
Example
```go
// Before
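agent := fiber.AcquireAgent()
req := agent.Request()
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("https://api.example.com/users")
if err := agent.Parse(); err != nil {
return fmt.Errorf("parse failed: %w", err)
}
// The agent must not be used after calling Bytes
status, body, errs := agent.Bytes()
if len(errs) > 0 {
return fmt.Errorf("request failed: %v", errs)
}
var users []user
if err := json.Unmarshal(body, &users); err != nil {
return fmt.Errorf("decode failed: %w", err)
}
fmt.Println(status, len(users))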
```
```go
// After
cli := client.New()
resp, err := cli.Get("https://api.example.com/users")
if err != nil {
return err
}
defer resp.Close()
var users []user
if err := resp.JSON(&users); err != nil {
return fmt.Errorf("decode failed: %w", err)
}
fmt.Println(resp.StatusCode(), len(users))
```
:::tip
If you need pooling, use `client.AcquireRequest`, `client.AcquireResponse`, and their corresponding release functions around a long-lived `client.Client` instead of the removed agent pool.
:::
5. **Fiber-level shortcuts**: The `fiber.Get`, `fiber.Post`, and similar top-level helpers are no longer exposed from the main module. Use the client package equivalents (`client.Get`, `client.Post`, etc.) which call the shared default client (or pass your own client instance for custom defaults).
Example
```go
// Before
status, body, errs := fiber.Get("https://api.example.com/health").String()
if len(errs) > 0 {
return fmt.Errorf("request failed: %v", errs)
}
fmt.Println(status, body)
```
```go
// After
resp, err := client.Get("https://api.example.com/health")
if err != nil {
return err
}
defer resp.Close()
fmt.Println(resp.StatusCode(), resp.String())
```
:::note
The `client.Get`/`client.Post` helpers use `client.C()` (the default shared client). For custom defaults, construct a client with `client.New()` and invoke its methods instead.
:::
#### Complete API Migration Reference
Click to expand full v2 → v3 API mapping tables
##### Core Concepts
| Description | v2 | v3 |
|-------------|----|----|
| Import | `github.com/gofiber/fiber/v2` | `github.com/gofiber/fiber/v3/client` |
| Client Concept | `*fiber.Agent` | `*client.Client` + `*client.Request` |
| Response Concept | `(code int, body []byte, errs []error)` | `(*client.Response, error)` |
##### Client/Agent Creation
| Description | v2 | v3 |
|-------------|----|----|
| Create Agent/Client | `fiber.AcquireAgent()` | `client.New()` |
| Get from pool | `fiber.AcquireAgent()` | `client.AcquireRequest()` |
| Release | `fiber.ReleaseAgent(a)` | `client.ReleaseRequest(req)` |
| With fasthttp.Client | - | `client.NewWithClient(c)` |
| With HostClient | - | `client.NewWithHostClient(hc)` |
| With LBClient | - | `client.NewWithLBClient(lb)` |
| Get Request object | `a.Request()` | `c.R()` |
| Default client | - | `client.C()` |
| Replace default | - | `client.Replace(c)` |
##### HTTP Methods
| Description | v2 | v3 (Client) | v3 (Request) |
|-------------|----|----|--------------|
| GET | `fiber.Get(url)` | `c.Get(url, cfg...)` | `req.Get(url)` |
| POST | `fiber.Post(url)` | `c.Post(url, cfg...)` | `req.Post(url)` |
| PUT | `fiber.Put(url)` | `c.Put(url, cfg...)` | `req.Put(url)` |
| PATCH | `fiber.Patch(url)` | `c.Patch(url, cfg...)` | `req.Patch(url)` |
| DELETE | `fiber.Delete(url)` | `c.Delete(url, cfg...)` | `req.Delete(url)` |
| HEAD | `fiber.Head(url)` | `c.Head(url, cfg...)` | `req.Head(url)` |
| OPTIONS | - | `c.Options(url, cfg...)` | `req.Options(url)` |
| Custom | - | `c.Custom(url, method, cfg...)` | `req.Custom(url, method)` |
##### URL & Method
| Description | v2 | v3 |
|-------------|----|----|
| Set URL | `req.SetRequestURI(url)` | `req.SetURL(url)` |
| Get URL | `req.URI().String()` | `req.URL()` |
| Set Method | `req.Header.SetMethod(method)` | `req.SetMethod(method)` |
| Set Base URL | - | `c.SetBaseURL(url)` |
##### Request Execution & Response
| Description | v2 | v3 |
|-------------|----|----|
| Parse Request | `a.Parse()` | Not needed |
| Execute (bytes) | `a.Bytes()` → `(code, body, errs)` | `req.Send()` → `(*Response, error)` |
| Execute (string) | `a.String()` | `resp.String()` |
| Execute (struct) | `a.Struct(&v)` | `resp.JSON(&v)` / `resp.XML(&v)` |
| Status Code | Return value `code` | `resp.StatusCode()` |
| Status Text | - | `resp.Status()` |
| Body (bytes) | Return value `body` | `resp.Body()` |
| Response Header | `resp.Header.Peek(key)` | `resp.Header(key)` |
| All Headers | `resp.Header.VisitAll(fn)` | `resp.Headers()` |
| Cookies | - | `resp.Cookies()` |
| Save to file | - | `resp.Save(path)` |
| Close | - | `resp.Close()` |
##### Headers
| Description | v2 | v3 (Client) | v3 (Request) |
|-------------|----|----|--------------|
| Set Header | `a.Set(k, v)` | `c.SetHeader(k, v)` | `req.SetHeader(k, v)` |
| Add Header | `a.Add(k, v)` | `c.AddHeader(k, v)` | `req.AddHeader(k, v)` |
| Multiple Headers | - | `c.SetHeaders(map)` | `req.SetHeaders(map)` |
| Bytes variants | `a.SetBytesK/V/KV()` | - | - |
##### User-Agent, Referer, Content-Type, Host
| Description | v2 | v3 (Client) | v3 (Request) |
|-------------|----|----|--------------|
| User-Agent | `a.UserAgent(ua)` | `c.SetUserAgent(ua)` | `req.SetUserAgent(ua)` |
| Referer | `a.Referer(ref)` | `c.SetReferer(ref)` | `req.SetReferer(ref)` |
| Content-Type | `a.ContentType(ct)` | - | `req.SetHeader("Content-Type", ct)` |
| Host | `a.Host(host)` | - | `req.SetHeader("Host", host)` |
| Connection Close | `a.ConnectionClose()` | - | `req.SetHeader("Connection", "close")` |
##### Cookies
| Description | v2 | v3 (Client) | v3 (Request) |
|-------------|----|----|--------------|
| Set Cookie | `a.Cookie(k, v)` | `c.SetCookie(k, v)` | `req.SetCookie(k, v)` |
| Multiple | `a.Cookies(k1, v1, ...)` | `c.SetCookies(map)` | `req.SetCookies(map)` |
| With Struct | - | `c.SetCookiesWithStruct(v)` | `req.SetCookiesWithStruct(v)` |
| Cookie Jar | - | `c.SetCookieJar(jar)` | - |
##### Query Parameters
| Description | v2 | v3 (Client) | v3 (Request) |
|-------------|----|----|--------------|
| Query String | `a.QueryString(qs)` | - | - |
| Add Param | - | `c.AddParam(k, v)` | `req.AddParam(k, v)` |
| Set Param | - | `c.SetParam(k, v)` | `req.SetParam(k, v)` |
| With Struct | - | `c.SetParamsWithStruct(v)` | `req.SetParamsWithStruct(v)` |
##### Path Parameters (NEW)
| Description | v2 | v3 (Client) | v3 (Request) |
|-------------|----|----|--------------|
| Set Path Param | - | `c.SetPathParam(k, v)` | `req.SetPathParam(k, v)` |
| Multiple | - | `c.SetPathParams(map)` | `req.SetPathParams(map)` |
| With Struct | - | `c.SetPathParamsWithStruct(v)` | `req.SetPathParamsWithStruct(v)` |
##### Request Body
| Description | v2 | v3 |
|-------------|----|----|
| Body (bytes) | `a.Body(body)` | `req.SetRawBody(body)` |
| Body (string) | `a.BodyString(body)` | `req.SetRawBody([]byte(body))` |
| Body Stream | `a.BodyStream(r, size)` | - |
| JSON | `a.JSON(v)` | `req.SetJSON(v)` |
| XML | `a.XML(v)` | `req.SetXML(v)` |
| CBOR (NEW) | - | `req.SetCBOR(v)` |
##### Form Data
| Description | v2 | v3 |
|-------------|----|----|
| Create Args | `fiber.AcquireArgs()` | Direct on Request |
| Send Form | `a.Form(args)` | `req.SetFormData(k, v)` |
| Add Form Data | `args.Set(k, v)` | `req.AddFormData(k, v)` |
| With Map | - | `req.SetFormDataWithMap(map)` |
| With Struct | - | `req.SetFormDataWithStruct(v)` |
##### File Upload
| Description | v2 | v3 |
|-------------|----|----|
| Multipart Form | `a.MultipartForm(args)` | Automatic |
| Boundary | `a.Boundary(b)` | `req.SetBoundary(b)` |
| Send File | `a.SendFile(f, field...)` | `req.AddFile(path)` |
| Multiple Files | `a.SendFiles(...)` | `req.AddFiles(files...)` |
| With Reader | - | `req.AddFileWithReader(name, r)` |
| FileData | `a.FileData(files...)` | `req.AddFiles(files...)` |
##### Timeout & TLS
| Description | v2 | v3 (Client) | v3 (Request) |
|-------------|----|----|--------------|
| Timeout | `a.Timeout(d)` | `c.SetTimeout(d)` | `req.SetTimeout(d)` |
| Max Redirects | `a.MaxRedirectsCount(n)` | Via Config | `req.SetMaxRedirects(n)` |
| TLS Config | `a.TLSConfig(cfg)` | `c.SetTLSConfig(cfg)` | - |
| Skip Verify | `a.InsecureSkipVerify()` | Via `tls.Config` | - |
| Certificates | - | `c.SetCertificates(...)` | - |
| Root Cert | - | `c.SetRootCertificate(path)` | - |
##### JSON/XML Encoder
| Description | v2 | v3 |
|-------------|----|----|
| JSON Encoder | `a.JSONEncoder(fn)` | `c.SetJSONMarshal(fn)` |
| JSON Decoder | `a.JSONDecoder(fn)` | `c.SetJSONUnmarshal(fn)` |
| XML Encoder | - | `c.SetXMLMarshal(fn)` |
| XML Decoder | - | `c.SetXMLUnmarshal(fn)` |
| CBOR (NEW) | - | `c.SetCBORMarshal/Unmarshal(fn)` |
##### Authentication
| Description | v2 | v3 |
|-------------|----|----|
| Basic Auth | `a.BasicAuth(user, pass)` | Via Header (Base64) |
##### Debug & Retry
| Description | v2 | v3 |
|-------------|----|----|
| Debug | `a.Debug(w...)` | `c.Debug()` |
| Disable Debug | - | `c.DisableDebug()` |
| Logger | - | `c.SetLogger(logger)` |
| Retry | `a.RetryIf(fn)` | `c.SetRetryConfig(cfg)` |
##### Reuse & Reset
| Description | v2 | v3 |
|-------------|----|----|
| Reuse Agent | `a.Reuse()` | Use pool |
| Reset Client | - | `c.Reset()` |
| Dest Buffer | `a.Dest(dest)` | - |
##### NEW in v3
| Feature | v3 API |
|---------|--------|
| Request Hooks | `c.AddRequestHook(fn)` |
| Response Hooks | `c.AddResponseHook(fn)` |
| Proxy | `c.SetProxyURL(url)` |
| Context | `req.SetContext(ctx)` |
| Dial Function | `c.SetDial(fn)` |
| Raw Request | `req.RawRequest` |
| Raw Response | `resp.RawResponse` |
##### Key Differences
1. **Architecture**: v2 `Agent` → v3 separate `Client`, `Request`, `Response`
2. **Error Handling**: v2 `[]error` → v3 single `error`
3. **Response**: v2 tuple `(code, body, errs)` → v3 `*Response` object
4. **No Parse()**: v3 auto-initializes requests
5. **Hooks**: v3 adds request/response middleware
6. **Path Params**: v3 native `:param` support
7. **Cookie Jar**: v3 built-in session management
8. **CBOR**: v3 adds CBOR encoding
9. **Context**: v3 native cancellation support
10. **Iterators**: v3 uses `iter.Seq2` for collections
11. **Bytes variants removed**: v2 `*Bytes*` methods gone
### 🛠️ Utils {#utils-migration}
Fiber v3 removes the in-repo `utils` package in favor of the external [`github.com/gofiber/utils/v2`](https://github.com/gofiber/utils) module.
1. Replace imports:
```go
- import "github.com/gofiber/fiber/v2/utils"
+ import "github.com/gofiber/utils/v2"
```
1. Review function changes:
| v2 function | v3 replacement |
| --- | --- |
| `AssertEqual` | removed; use testing libraries like [`github.com/stretchr/testify/assert`](https://pkg.go.dev/github.com/stretchr/testify/assert) |
| `ToLowerBytes` | `utils.ToLowerBytes` |
| `ToUpperBytes` | `utils.ToUpperBytes` |
| `TrimRightBytes` | `utils.TrimRight` |
| `TrimLeftBytes` | `utils.TrimLeft` |
| `TrimBytes` | `utils.Trim` |
| `EqualFoldBytes` | `utils.EqualFold` |
| `UUID` | `utils.UUID` |
| `UUIDv4` | `utils.UUIDv4` |
| `FunctionName` | `utils.FunctionName` |
| `GetArgument` | `utils.GetArgument` |
| `IncrementIPRange` | `utils.IncrementIPRange` |
| `ConvertToBytes` | `utils.ConvertToBytes` |
| `CopyString` | `utils.CopyString` |
| `CopyBytes` | `utils.CopyBytes` |
| `ByteSize` | `utils.ByteSize` |
| `ToString` | `utils.ToString` |
| `UnsafeString` | `utils.UnsafeString` |
| `UnsafeBytes` | `utils.UnsafeBytes` |
| `GetString` | removed; use `utils.ToString` or the standard library |
| `GetBytes` | removed; use `utils.CopyBytes` or `[]byte(s)` |
| `ImmutableString` | removed; strings are already immutable |
| `GetMIME` | `utils.GetMIME` |
| `ParseVendorSpecificContentType` | `utils.ParseVendorSpecificContentType` |
| `StatusMessage` | `utils.StatusMessage` |
| `IsIPv4` | `utils.IsIPv4` |
| `IsIPv6` | `utils.IsIPv6` |
| `ToLower` | `utils.ToLower` |
| `ToUpper` | `utils.ToUpper` |
| `TrimLeft` | `strings.TrimLeft` |
| `Trim` | `strings.Trim` |
| `TrimRight` | `strings.TrimRight` |
| `EqualFold` | `strings.EqualFold` |
| `StartTimeStampUpdater` | `utils.StartTimeStampUpdater` (new `utils.Timestamp` provides the current value) |
1. Update your code. For example:
```go
// v2
import oldutils "github.com/gofiber/fiber/v2/utils"
func demo() {
b := oldutils.TrimBytes([]byte(" fiber "), ' ')
id := oldutils.UUIDv4()
s := oldutils.GetString([]byte("foo"))
}
// v3
import (
"github.com/gofiber/utils/v2"
"strings"
)
func demo() {
s := utils.TrimSpace(" fiber ")
id := utils.UUIDv4()
str := utils.ToString([]byte("foo"))
t := strings.TrimRight("bar ", " ")
}
```
The `github.com/gofiber/utils/v2` module also introduces new helpers like `ParseInt`, `ParseUint`, `Walk`, `ReadFile`, and `Timestamp`.
### 🧬 Middlewares
#### Important Change for Accessing Middleware Data
**Change:** In Fiber v2, some middlewares set data in `c.Locals()` using string keys (e.g., `c.Locals("requestid")`). In Fiber v3, to align with Go's context best practices and prevent key collisions, these middlewares now store their specific data in the request's context using unexported keys of custom types.
**Impact:** Directly accessing these middleware-provided values via `c.Locals("some_string_key")` will no longer work.
**Migration Action:**
The `ContextKey` configuration option has been removed from all middlewares. Values are no longer stored under user-defined keys. You must update your code to use the dedicated exported functions provided by each affected middleware to retrieve its data from the context.
**Examples of new helper functions to use:**
- `requestid.FromContext(c)`
- `csrf.TokenFromContext(c)`
- `csrf.HandlerFromContext(c)`
- `session.FromContext(c)`
- `basicauth.UsernameFromContext(c)`
- `keyauth.TokenFromContext(c)`
**For logging these values:**
The recommended approach is to use the `CustomTags` feature of the Logger middleware, which allows you to call these specific `FromContext` functions. Refer to the [Logger section in "What's New"](#logger) for detailed examples.
:::note
If you were manually setting and retrieving your own application-specific values in `c.Locals()` using string keys, that functionality remains unchanged. This change specifically pertains to how Fiber's built-in (and some contrib) middlewares expose their data.
:::
#### BasicAuth
The `Authorizer` callback now receives the current request context. Update custom
functions from:
```go
Authorizer: func(user, pass string) bool {
// v2 style
return user == "admin" && pass == "secret"
}
```
to:
```go
Authorizer: func(user, pass string, _ fiber.Ctx) bool {
// v3 style with access to the Fiber context
return user == "admin" && pass == "secret"
}
```
Passwords configured for BasicAuth must now be pre-hashed. If no prefix is supplied, the middleware expects a SHA-256 digest encoded in hex. Common prefixes such as `{SHA256}` and `{SHA512}`, as well as bcrypt hashes, are also supported. Plaintext passwords are no longer accepted. Unauthorized responses also include a `Vary: Authorization` header for correct caching behavior.
You can also set the optional `HeaderLimit` and `Charset`
options to further control authentication behavior.
#### KeyAuth
The keyauth middleware was updated to introduce a configurable `Realm` field for the `WWW-Authenticate` header.
The old string-based `KeyLookup` configuration has been replaced with an `Extractor` field, and the `AuthScheme` field has been removed. The auth scheme is now inferred from the extractor used (e.g., `keyauth.FromAuthHeader`). Use helper functions like `keyauth.FromHeader`, `keyauth.FromAuthHeader`, or `keyauth.FromCookie` to define where the key should be retrieved from. Multiple sources can be combined with `keyauth.Chain`.
New `Challenge`, `Error`, `ErrorDescription`, `ErrorURI`, and `Scope` options let you customize challenge responses, include Bearer error parameters, and specify required scopes. `ErrorURI` values are validated as absolute, credentials containing whitespace are rejected, and when multiple authorization extractors are chained, all schemes are advertised in the `WWW-Authenticate` header. The middleware defers emitting `WWW-Authenticate` until a 401 status is final, and `FromAuthHeader` now trims surrounding whitespace.
```go
// Before
app.Use(keyauth.New(keyauth.Config{
KeyLookup: "header:Authorization",
AuthScheme: "Bearer",
Validator: validateAPIKey,
}))
// After
app.Use(keyauth.New(keyauth.Config{
Extractor: keyauth.FromAuthHeader(fiber.HeaderAuthorization, "Bearer"),
Validator: validateAPIKey,
}))
```
Combine multiple sources with `keyauth.Chain()` when needed.
#### Cache
The deprecated `Store` and `Key` fields were removed. Use `Storage` and
`KeyGenerator` instead to configure caching backends and cache keys.
Defaults also changed: the middleware now emits `Cache-Control` headers, the default `Expiration` increased to `5 minutes` (from `1 minute`), and a new `MaxBytes` limit of `1 MB` (previously unlimited) now caps cached payloads.
To restore v2 behavior:
- Set `DisableCacheControl` to `true` to suppress automatic `Cache-Control` headers.
- Configure `Expiration` to `1*time.Minute`.
- Set `MaxBytes` to `0` (or a higher value) when caching large responses.
#### CORS
The CORS middleware has been updated to use slices instead of strings for the `AllowOrigins`, `AllowMethods`, `AllowHeaders`, and `ExposeHeaders` fields. Here's how you can update your code:
```go
// Before
app.Use(cors.New(cors.Config{
AllowOrigins: "https://example.com,https://example2.com",
AllowMethods: strings.Join([]string{fiber.MethodGet, fiber.MethodPost}, ","),
AllowHeaders: "Content-Type",
ExposeHeaders: "Content-Length",
}))
// After
app.Use(cors.New(cors.Config{
AllowOrigins: []string{"https://example.com", "https://example2.com"},
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost},
AllowHeaders: []string{"Content-Type"},
ExposeHeaders: []string{"Content-Length"},
}))
```
#### CSRF
- **Field Renaming**: The `Expiration` field in the CSRF middleware configuration has been renamed to `IdleTimeout` to better describe its functionality. Additionally, the default value has been reduced from 1 hour to 30 minutes. Update your code as follows:
```go
// Before
app.Use(csrf.New(csrf.Config{
Expiration: 10 * time.Minute,
}))
// After
app.Use(csrf.New(csrf.Config{
IdleTimeout: 10 * time.Minute,
}))
```
- **Session Key Removal**: The `SessionKey` field has been removed from the CSRF middleware configuration. The session key is now an unexported constant within the middleware to avoid potential key collisions in the session store.
- **KeyLookup Field Removal**: The `KeyLookup` field has been removed from the CSRF middleware configuration. This field was deprecated and is no longer needed as the middleware now uses a more secure approach for token management.
- **DisableValueRedaction Toggle**: CSRF redacts tokens and storage keys by default; set `DisableValueRedaction` to `true` when diagnostics require the raw values.
- **Default KeyGenerator**: Changed from `utils.UUIDv4` to `utils.SecureToken`, producing base64-encoded tokens instead of UUID format.
```go
// Before
app.Use(csrf.New(csrf.Config{
KeyLookup: "header:X-Csrf-Token",
// other config...
}))
// After - use Extractor instead
app.Use(csrf.New(csrf.Config{
Extractor: csrf.FromHeader("X-Csrf-Token"),
// other config...
}))
```
- **FromCookie Extractor Removal**: The `csrf.FromCookie` extractor has been intentionally removed for security reasons. Using cookie-based extraction defeats the purpose of CSRF protection by making the extracted token always match the cookie value.
```go
// Before - This was a security vulnerability
app.Use(csrf.New(csrf.Config{
Extractor: csrf.FromCookie("csrf_token"), // ❌ Insecure!
}))
// After - Use secure extractors instead
app.Use(csrf.New(csrf.Config{
	Extractor: csrf.FromHeader("X-Csrf-Token"), // ✅ Secure
	// or
	// Extractor: csrf.FromForm("_csrf"), // ✅ Secure
	// or
	// Extractor: csrf.FromQuery("csrf_token"), // ✅ Acceptable
}))
```
**Security Note**: The removal of `FromCookie` prevents a common misconfiguration that would completely bypass CSRF protection. The middleware uses the Double Submit Cookie pattern, which requires the token to be submitted through a different channel than the cookie to provide meaningful protection.
#### Idempotency
- **DisableValueRedaction Toggle**: The idempotency middleware now hides keys in logs and error paths by default, with a `DisableValueRedaction` boolean (default `false`) to reveal them when needed.
#### Timeout
The timeout middleware now accepts a configuration struct instead of a duration.
Update your code as follows:
```go
// Before
app.Use(timeout.New(handler, 2*time.Second))
// After
app.Use(timeout.New(handler, timeout.Config{Timeout: 2 * time.Second}))
```
**Important behavioral changes:**
- The middleware now returns immediately on timeout without waiting for the handler (using the new Abandon mechanism).
- Handlers can detect timeouts by listening on `c.Context().Done()` and return early.
- Panics in the handler are caught and converted to `500 Internal Server Error`.
#### Filesystem
The filesystem middleware has been removed from the core, so migrate to the static middleware instead:
```go
// Before
app.Use(filesystem.New(filesystem.Config{
Root: http.Dir("./assets"),
}))
app.Use(filesystem.New(filesystem.Config{
Root: http.Dir("./assets"),
Browse: true,
Index: "index.html",
MaxAge: 3600,
}))
```
```go
// After
app.Use(static.New("", static.Config{
FS: os.DirFS("./assets"),
}))
app.Use(static.New("", static.Config{
FS: os.DirFS("./assets"),
Browse: true,
IndexNames: []string{"index.html"},
MaxAge: 3600,
}))
```
#### EnvVar
The `ExcludeVars` option has been removed. Remove any references to it and use
`ExportVars` to explicitly list environment variables that should be exposed.
#### Healthcheck
Previously, the Healthcheck middleware was configured with a combined setup for liveness and readiness probes:
```go
// Before
app.Use(healthcheck.New(healthcheck.Config{
LivenessProbe: func(c fiber.Ctx) bool {
return true
},
LivenessEndpoint: "/live",
ReadinessProbe: func(c fiber.Ctx) bool {
return serviceA.Ready() && serviceB.Ready() && ...
},
ReadinessEndpoint: "/ready",
}))
```
With the new version, each health check endpoint is configured separately, allowing for more flexibility:
```go
// After
// Default liveness endpoint configuration
app.Get(healthcheck.LivenessEndpoint, healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return true
},
}))
// Default readiness endpoint configuration
app.Get(healthcheck.ReadinessEndpoint, healthcheck.New())
// New default startup endpoint configuration
// Default endpoint is /startupz
app.Get(healthcheck.StartupEndpoint, healthcheck.New(healthcheck.Config{
Probe: func(c fiber.Ctx) bool {
return serviceA.Ready() && serviceB.Ready() && ...
},
}))
// Custom liveness endpoint configuration
app.Get("/live", healthcheck.New())
```
#### Monitor
Since v3 the Monitor middleware has been moved to the [Contrib package](https://github.com/gofiber/contrib/tree/main/monitor)
```go
// Before
import "github.com/gofiber/fiber/v2/middleware/monitor"
app.Use("/metrics", monitor.New())
```
You only need to change the import path to the contrib package.
```go
// After
import "github.com/gofiber/contrib/monitor"
app.Use("/metrics", monitor.New())
```
#### Proxy
In previous versions, TLS settings for the proxy middleware were set using the `WithTlsConfig` method. This method has been removed. Configure TLS via the `TLSConfig` field in the `Config` struct or, for helpers such as `proxy.Forward`, by passing a custom `fasthttp.Client` to `proxy.WithClient`, as shown below.
##### Before (v2 usage)
```go
proxy.WithTlsConfig(&tls.Config{
InsecureSkipVerify: true,
})
// Forward to url
app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif"))
```
##### After (v3 usage)
```go
proxy.WithClient(&fasthttp.Client{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
})
// Forward to url
app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif"))
```
`proxy.Balancer` also adopts the common middleware signature pattern and now accepts an optional variadic config: call `proxy.Balancer()` to use the defaults or continue passing a single `proxy.Config` value as in v2.
#### Session
`session.New()` now returns a middleware handler. When using the store pattern,
create a store with `session.NewStore()` or call `Store()` on the middleware.
Sessions obtained from a store must be released manually via `sess.Release()`.
Additionally, replace the deprecated `KeyLookup` option with extractor
functions such as `session.FromCookie()` or `session.FromHeader()`. Multiple
extractors can be combined with `session.Chain()`.
```go
// Before
app.Use(session.New(session.Config{
KeyLookup: "cookie:session_id",
Store: session.NewStore(),
}))
```
```go
// After
app.Use(session.New(session.Config{
Extractor: session.FromCookie("session_id"),
Store: session.NewStore(),
}))
```
See the [Session Middleware Migration Guide](./middleware/session.md#migration-guide)
for complete details.
================================================
FILE: error.go
================================================
package fiber
import (
"encoding/json"
"errors"
"github.com/gofiber/schema"
)
var (
	// errUnreachable is meant to be returned (optionally wrapped) from code paths
	// that should never run, in places where panicking is undesirable, such as a handler.
	// It stays unexported because users should never need to see or match it.
	errUnreachable = errors.New("fiber: unreachable code, please create an issue at github.com/gofiber/fiber")
)
// General errors
var (
ErrGracefulTimeout = errors.New("shutdown: graceful timeout has been reached, exiting")
// ErrNotRunning indicates that a Shutdown method was called when the server was not running.
ErrNotRunning = errors.New("shutdown: server is not running")
// ErrHandlerExited is returned by App.Test if a handler panics or calls runtime.Goexit().
ErrHandlerExited = errors.New("runtime.Goexit() called in handler or server panic")
// ErrNoViewEngineConfigured indicates that a helper requiring a view engine was invoked without one configured.
ErrNoViewEngineConfigured = errors.New("fiber: no view engine configured")
// ErrAutoCertWithCertFile indicates AutoCertManager cannot be used with CertFile/CertKeyFile.
ErrAutoCertWithCertFile = errors.New("tls: AutoCertManager cannot be combined with CertFile/CertKeyFile")
)
// Fiber redirection errors
var (
ErrRedirectBackNoFallback = NewError(StatusInternalServerError, "Referer not found, you have to enter fallback URL for redirection.")
)
// Range errors
var (
ErrRangeMalformed = errors.New("range: malformed range header string")
ErrRangeTooLarge = NewError(StatusRequestedRangeNotSatisfiable, "range: too many ranges")
ErrRangeUnsatisfiable = errors.New("range: unsatisfiable range")
)
var (
	// ErrCustomBinderNotFound is a binder error reported when no custom binder
	// can be found under the requested name.
	ErrCustomBinderNotFound = errors.New("binder: custom binder not found, please be sure to enter the right name")
)
// ErrNoHandlers is the format error returned when c.Format is called with no arguments.
var ErrNoHandlers = errors.New("format: at least one handler is required, but none were set")
// ConversionError is an alias that exposes the internal schema.ConversionError for public use.
type ConversionError = schema.ConversionError
// UnknownKeyError is an alias that exposes the internal schema.UnknownKeyError for public use.
type UnknownKeyError = schema.UnknownKeyError
// EmptyFieldError is an alias that exposes the internal schema.EmptyFieldError for public use.
type EmptyFieldError = schema.EmptyFieldError
// MultiError is an alias that exposes the internal schema.MultiError for public use.
type MultiError = schema.MultiError
// InvalidUnmarshalError aliases json.InvalidUnmarshalError, which describes an
// invalid argument passed to Unmarshal (the argument must be a non-nil pointer).
type InvalidUnmarshalError = json.InvalidUnmarshalError
// MarshalerError aliases json.MarshalerError, an error from calling a
// MarshalJSON or MarshalText method.
type MarshalerError = json.MarshalerError
// SyntaxError aliases json.SyntaxError, a description of a JSON syntax error.
type SyntaxError = json.SyntaxError
// UnmarshalTypeError aliases json.UnmarshalTypeError, which describes a JSON
// value that was not appropriate for a value of a specific Go type.
type UnmarshalTypeError = json.UnmarshalTypeError
// UnsupportedTypeError aliases json.UnsupportedTypeError, returned by Marshal
// when attempting to encode an unsupported value type.
type UnsupportedTypeError = json.UnsupportedTypeError
// UnsupportedValueError aliases json.UnsupportedValueError, which describes
// unsupported values encountered during encoding.
type UnsupportedValueError = json.UnsupportedValueError
================================================
FILE: error_test.go
================================================
package fiber
import (
"encoding/json"
"errors"
"testing"
"github.com/gofiber/schema"
"github.com/stretchr/testify/require"
)
func Test_ConversionError(t *testing.T) {
t.Parallel()
ok := errors.As(ConversionError{}, &schema.ConversionError{})
require.True(t, ok)
}
func Test_UnknownKeyError(t *testing.T) {
t.Parallel()
ok := errors.As(UnknownKeyError{}, &schema.UnknownKeyError{})
require.True(t, ok)
}
func Test_EmptyFieldError(t *testing.T) {
t.Parallel()
ok := errors.As(EmptyFieldError{}, &schema.EmptyFieldError{})
require.True(t, ok)
}
func Test_MultiError(t *testing.T) {
t.Parallel()
ok := errors.As(MultiError{}, &schema.MultiError{})
require.True(t, ok)
}
func Test_InvalidUnmarshalError(t *testing.T) {
t.Parallel()
var e *json.InvalidUnmarshalError
ok := errors.As(&InvalidUnmarshalError{}, &e)
require.True(t, ok)
}
func Test_MarshalerError(t *testing.T) {
t.Parallel()
var e *json.MarshalerError
ok := errors.As(&MarshalerError{}, &e)
require.True(t, ok)
}
func Test_SyntaxError(t *testing.T) {
t.Parallel()
var e *json.SyntaxError
ok := errors.As(&SyntaxError{}, &e)
require.True(t, ok)
}
func Test_UnmarshalTypeError(t *testing.T) {
t.Parallel()
var e *json.UnmarshalTypeError
ok := errors.As(&UnmarshalTypeError{}, &e)
require.True(t, ok)
}
func Test_UnsupportedTypeError(t *testing.T) {
t.Parallel()
var e *json.UnsupportedTypeError
ok := errors.As(&UnsupportedTypeError{}, &e)
require.True(t, ok)
}
// Test_UnsupportedValueError checks that the UnsupportedValueError alias can be
// extracted as *json.UnsupportedValueError via errors.As.
// (Renamed from the misspelled Test_UnsupportedValeError.)
func Test_UnsupportedValueError(t *testing.T) {
	t.Parallel()
	var e *json.UnsupportedValueError
	ok := errors.As(&UnsupportedValueError{}, &e)
	require.True(t, ok)
}
================================================
FILE: errors_internal.go
================================================
package fiber
import (
"errors"
)
// errBindPoolTypeAssertion reports a failed type assertion to *Bind
// (the name suggests values obtained from a pool — confirm at call sites).
var errBindPoolTypeAssertion = errors.New("failed to type-assert to *Bind")
// errCustomCtxTypeAssertion reports a failed type assertion to CustomCtx.
var errCustomCtxTypeAssertion = errors.New("failed to type-assert to CustomCtx")
// errInvalidEscapeSequence reports an invalid escape sequence.
var errInvalidEscapeSequence = errors.New("invalid escape sequence")
// errRedirectTypeAssertion reports a failed type assertion to *Redirect.
var errRedirectTypeAssertion = errors.New("failed to type-assert to *Redirect")
================================================
FILE: extractors/README.md
================================================
# Extractors Package
Package providing shared value extraction utilities for Fiber middleware packages.
## Audience
**This README is targeted at middleware developers and contributors.** If you are a Fiber framework user looking to use extractors in your application, please refer to the [Extractors Guide](https://docs.gofiber.io/guide/extractors) instead.
## Architecture
### Core Types
- `Extractor`: Core extraction function with metadata
- `Source`: Enumeration of extraction sources (Header, AuthHeader, Query, Form, Param, Cookie, Custom)
- `ErrNotFound`: Standardized error for missing values
### Extractor Structure
```go
type Extractor struct {
Extract func(fiber.Ctx) (string, error)
Key string // The parameter/header name used for extraction
AuthScheme string // The auth scheme used, e.g., "Bearer"
Chain []Extractor // For chained extractors, stores all extractors in the chain
Source Source // The type of source being extracted from
}
```
### Available Functions
- `FromAuthHeader(authScheme string)`: Extract from Authorization header with optional scheme
- `FromCookie(key string)`: Extract from HTTP cookies
- `FromParam(param string)`: Extract from URL path parameters
- `FromForm(param string)`: Extract from form data
- `FromHeader(header string)`: Extract from custom HTTP headers
- `FromQuery(param string)`: Extract from URL query parameters
- `FromCustom(key string, fn func(fiber.Ctx) (string, error))`: Define custom extraction logic with metadata
- `Chain(extractors ...Extractor)`: Chain multiple extractors with fallback
### Source Inspection
The `Source` field provides **security-aware extraction** by explicitly identifying the origin of extracted values. This enables middleware to enforce security policies based on data source:
```go
switch extractor.Source {
case SourceAuthHeader:
// Authorization header - commonly used for authentication tokens
case SourceHeader:
// Custom HTTP headers - application-specific data
case SourceCookie:
// HTTP cookies - client-side stored data
case SourceQuery:
// URL query parameters - visible in URLs and logs (security consideration)
case SourceForm:
// Form data - POST body data
case SourceParam:
// URL path parameters - route-based data
case SourceCustom:
// Custom extraction logic
}
```
### Chain Behavior
The `Chain` function implements fallback logic:
- Returns first successful extraction (non-empty value, no error)
- If all extractors fail, returns the last error encountered or `ErrNotFound`
- **Skips extractors with `nil` Extract functions** (graceful error handling)
- Preserves metadata from first extractor for introspection
- Stores defensive copy for runtime inspection via the `Chain` field
## Security Considerations
### Source Awareness and Custom Extractors
As described in the [Source Inspection](#source-inspection) section, the `Source` field enables middleware to enforce security policies based on data source:
- **CSRF Protection**: The double-submit-cookie pattern requires tokens to be submitted in both a cookie AND a form field/header. Source awareness allows CSRF middleware to verify that tokens come from both expected sources and not, for example, only from cookies.
- **Authentication**: Security middleware can enforce source-specific policies (e.g., auth tokens from headers, not query parameters)
- **Audit Trails**: Source information enables security analysis and compliance reporting
However, when using `FromCustom`, middleware cannot determine the source of the extracted value, which can limit the ability of a middleware to provide warnings about potential security risks. Documentation and examples should clearly warn about these risks when using custom extractors.
================================================
FILE: extractors/extractors.go
================================================
// Package extractors provides shared value extraction utilities for Fiber middleware.
// This package helps reduce code duplication across middleware packages
// while ensuring consistent behavior, security practices, and RFC compliance.
// It can extract string values from various HTTP request sources including
// headers, cookies, query parameters, form data, and URL parameters.
//
// Example usage:
//
//	import "github.com/gofiber/fiber/v3/extractors"
//
//	// Extract from Authorization header
//	authExtractor := extractors.FromAuthHeader("Bearer")
//
//	// Chain multiple sources with fallback
//	tokenExtractor := extractors.Chain(
//		extractors.FromHeader("X-API-Key"),
//		extractors.FromCookie("api_key"),
//		extractors.FromQuery("token"),
//	)
//
// Security considerations:
//   - Query parameters and form data can leak sensitive information
//   - Use HTTPS to protect extracted values in transit
//   - Consider source-specific security policies for your use case
package extractors
import (
"errors"
"net/url"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
)
// Source identifies which part of the HTTP request an Extractor reads its value from.
// It is informational metadata: middleware can inspect it (directly, or per link via
// Extractor.Chain) to apply source-specific policies, such as refusing credentials
// that arrive in the URL.
//
// The numeric values are assigned by iota in declaration order, so reordering the
// constants below would change the value of every exported Source constant.
type Source int

const (
	// SourceHeader indicates the value is extracted from an HTTP header (FromHeader).
	SourceHeader Source = iota
	// SourceAuthHeader indicates the value is extracted from the Authorization header (FromAuthHeader).
	SourceAuthHeader
	// SourceForm indicates the value is extracted from form data (FromForm).
	SourceForm
	// SourceQuery indicates the value is extracted from URL query parameters (FromQuery).
	SourceQuery
	// SourceParam indicates the value is extracted from URL path parameters (FromParam).
	SourceParam
	// SourceCookie indicates the value is extracted from cookies (FromCookie).
	SourceCookie
	// SourceCustom indicates the value is extracted using a custom extractor function
	// (FromCustom); it is also the Source of an empty Chain().
	SourceCustom
)
var (
	// ErrNotFound is the sentinel error the built-in extractors report when the
	// requested value is absent or empty. Compare against it with errors.Is.
	ErrNotFound = errors.New("value not found")
)
// Extractor pairs a value-extraction function with metadata describing where the
// value comes from. The metadata can be inspected without invoking Extract.
type Extractor struct {
	// Extract returns the extracted value, or an error when no usable value is
	// present (the built-in extractors in this package return ErrNotFound).
	// Chain skips links whose Extract is nil.
	Extract    func(fiber.Ctx) (string, error)
	Key        string      // The parameter/header/cookie name used for extraction; a descriptive label for FromCustom
	AuthScheme string      // The auth scheme used, e.g., "Bearer"; set by FromAuthHeader, empty for the other constructors here
	Chain      []Extractor // For chained extractors, a copy of all extractors in the chain (for introspection)
	Source     Source      // The type of source being extracted from
}
// FromAuthHeader creates an Extractor that reads credentials from the Authorization header.
//
// When authScheme is non-empty, the header must have the exact form
//
//	<auth-scheme> SP <token68>
//
// Parsing rules:
//   - The scheme is matched case-insensitively.
//   - Exactly one space (SP) must separate the scheme from the credentials.
//     RFC 9110 (Section 11.4) permits 1*SP here, but this extractor deliberately
//     rejects additional spaces and tabs to keep parsing strict and unambiguous.
//   - The credentials must be a valid token68 (RFC 9110 Section 11.2): one or more of
//     A-Z, a-z, 0-9, '-', '.', '_', '~', '+', '/', optionally followed by '=' padding.
//     Whitespace is never allowed, '=' may only appear as trailing padding, and the
//     token may not start with '='.
//
// When authScheme is empty, the raw header value is returned without any validation.
//
// Parameters:
//   - authScheme: The auth scheme to strip from the header value (e.g., "Bearer", "Basic").
//     If empty, the entire header value is returned without validation.
//
// Returns:
//
//	An Extractor (Source: SourceAuthHeader, Key: "Authorization", AuthScheme: authScheme).
//	Its Extract returns ErrNotFound if the header is missing, the scheme does not match,
//	the separator is not a single space, or the credentials are not valid token68.
//
// Examples:
//
//	// Extract Bearer token with validation
//	extractor := FromAuthHeader("Bearer")
//	// Input: "Bearer abc123"      -> Output: "abc123"
//	// Input: "bearer abc123"      -> Output: "abc123" (scheme is case-insensitive)
//	// Input: "Bearer abc def"     -> Output: ErrNotFound (space in token)
//	// Input: "Bearer  abc123"     -> Output: ErrNotFound (more than one space)
//	// Input: "Basic dXNlcjpwYXNz" -> Output: ErrNotFound (wrong scheme)
//
//	// Extract raw header value (no validation)
//	extractor := FromAuthHeader("")
//	// Input: "CustomAuth token123" -> Output: "CustomAuth token123"
func FromAuthHeader(authScheme string) Extractor {
	return Extractor{
		Extract: func(c fiber.Ctx) (string, error) {
			authHeader := c.Get(fiber.HeaderAuthorization)
			if authHeader == "" {
				return "", ErrNotFound
			}
			// Without a scheme the caller wants the raw, unvalidated header value.
			if authScheme == "" {
				return authHeader, nil
			}
			schemeLen := len(authScheme)
			// Need "<scheme>" + " " + at least one credential byte; the scheme is
			// compared case-insensitively and must be followed by exactly one SP.
			if len(authHeader) <= schemeLen+1 ||
				!utils.EqualFold(authHeader[:schemeLen], authScheme) ||
				authHeader[schemeLen] != ' ' {
				return "", ErrNotFound
			}
			// Anything after the single separator space is the credential; a second
			// space or a tab makes it fail token68 validation.
			token := authHeader[schemeLen+1:]
			if !isValidToken68(token) {
				return "", ErrNotFound
			}
			return token, nil
		},
		Key:        fiber.HeaderAuthorization,
		Source:     SourceAuthHeader,
		AuthScheme: authScheme,
	}
}
// FromCookie creates an Extractor that reads the named cookie from the request.
//
// Parameters:
//   - key: The name of the cookie whose value should be extracted.
//
// Returns:
//
//	An Extractor (Source: SourceCookie, Key: key) whose Extract yields the cookie
//	value, or ErrNotFound when the cookie is absent or empty.
//
// Security Note:
//
//	Cookies are generally a safer carrier for sensitive data than query parameters,
//	since they are not part of the URL and so are not typically written to access
//	logs or browser history. Make sure such cookies carry appropriate security flags.
//
// Example:
//
//	extractor := FromCookie("session_id")
//	// Cookie: "session_id=abc123" -> Output: "abc123"
//	// Missing cookie -> Output: ErrNotFound
func FromCookie(key string) Extractor {
	readCookie := func(c fiber.Ctx) (string, error) {
		if cookieValue := c.Cookies(key); cookieValue != "" {
			return cookieValue, nil
		}
		return "", ErrNotFound
	}
	return Extractor{Extract: readCookie, Key: key, Source: SourceCookie}
}
// FromParam creates an Extractor that retrieves a value from a specified URL parameter in the request.
// URL parameters are extracted from the route path (e.g., /users/:id).
//
// Percent-encoded sequences in the parameter are decoded exactly once:
//   - With the app's UnescapePath option disabled (the default), Fiber matches routes
//     against the raw path, so the value is decoded here with url.PathUnescape.
//     A malformed escape sequence (e.g. "%zz") yields ErrNotFound.
//   - With UnescapePath enabled, Fiber has already decoded the path before routing,
//     so the value is returned as-is. Decoding it again would double-decode
//     ("%2525" -> "%") and reject legitimate values containing a literal '%'.
//
// SECURITY WARNING: Extracting values from URL parameters can leak sensitive information through:
//   - Server access logs and error logs
//   - Browser referrer headers when following links
//   - Proxy and intermediary server logs
//   - Browser history and bookmarks
//   - Network monitoring tools
//
// For sensitive data, prefer FromAuthHeader, FromCookie, or FromHeader instead.
//
// Parameters:
//   - param: The name of the URL parameter from which to extract the value.
//
// Returns:
//
//	An Extractor (Source: SourceParam, Key: param) that attempts to retrieve the value
//	from the specified URL parameter. Returns ErrNotFound if the parameter is missing
//	or empty, or if it contains an invalid percent-encoding.
//
// Example:
//
//	// Route: GET /users/:userId/posts/:postId
//	userExtractor := FromParam("userId")
//	postExtractor := FromParam("postId")
//	// URL: /users/123/posts/456 -> userId: "123", postId: "456"
//	// URL: /users/a%20b/posts/1 -> userId: "a b"
func FromParam(param string) Extractor {
	return Extractor{
		Extract: func(c fiber.Ctx) (string, error) {
			value := c.Params(param)
			if value == "" {
				return "", ErrNotFound
			}
			// The router already decoded the path; a second decode would be wrong.
			if c.App().Config().UnescapePath {
				return value, nil
			}
			unescapedValue, err := url.PathUnescape(value)
			if err != nil {
				return "", ErrNotFound
			}
			return unescapedValue, nil
		},
		Key:    param,
		Source: SourceParam,
	}
}
// FromForm creates an Extractor that reads a single form field from the request.
// Form data is typically submitted via POST requests with content-type
// application/x-www-form-urlencoded.
//
// SECURITY WARNING: Extracting values from form data can leak sensitive information through:
//   - Server access logs and error logs
//   - Browser referrer headers (especially if form is submitted via GET)
//   - Proxy and intermediary server logs
//   - Browser history (if form uses GET method)
//
// For sensitive data, prefer FromAuthHeader or FromCookie instead.
// If using form data, ensure the form uses POST method and HTTPS.
//
// Parameters:
//   - param: The name of the form field to read.
//
// Returns:
//
//	An Extractor (Source: SourceForm, Key: param) whose Extract yields the field
//	value, or ErrNotFound when the field is absent or empty.
//
// Example:
//
//	extractor := FromForm("username")
//	// Form data: "username=john_doe&password=secret" -> Output: "john_doe"
//	// Missing field -> Output: ErrNotFound
func FromForm(param string) Extractor {
	var ex Extractor
	ex.Key = param
	ex.Source = SourceForm
	ex.Extract = func(c fiber.Ctx) (string, error) {
		field := c.FormValue(param)
		if len(field) == 0 {
			return "", ErrNotFound
		}
		return field, nil
	}
	return ex
}
// FromHeader creates an Extractor that reads the named HTTP request header.
// Headers are commonly used to carry API keys, tokens, and other metadata.
//
// Parameters:
//   - header: The name of the HTTP header to read.
//
// Returns:
//
//	An Extractor (Source: SourceHeader, Key: header) whose Extract yields the header
//	value, or ErrNotFound when the header is absent or empty.
//
// Security Note:
//
//	Headers are generally suitable for sensitive data because access logs do not
//	record them by default. Be aware, however, that some proxies may log headers.
//
// Example:
//
//	extractor := FromHeader("X-API-Key")
//	// Header: "X-API-Key: abc123" -> Output: "abc123"
//	// Missing header -> Output: ErrNotFound
func FromHeader(header string) Extractor {
	lookup := func(c fiber.Ctx) (string, error) {
		headerValue := c.Get(header)
		if headerValue != "" {
			return headerValue, nil
		}
		return "", ErrNotFound
	}
	return Extractor{
		Key:     header,
		Source:  SourceHeader,
		Extract: lookup,
	}
}
// FromQuery creates an Extractor that reads a query-string parameter
// (e.g. "token" in ?token=value&foo=bar).
//
// SECURITY WARNING: Extracting values from URL query parameters can leak sensitive information through:
//   - Server access logs and error logs
//   - Browser referrer headers when following links
//   - Proxy and intermediary server logs
//   - Browser history and bookmarks
//   - Network monitoring tools and packet sniffers
//   - Web browser developer tools
//
// For sensitive data, prefer FromAuthHeader, FromCookie, or FromHeader instead.
// If query parameters must be used, ensure HTTPS is enforced.
//
// Parameters:
//   - param: The name of the query parameter to read.
//
// Returns:
//
//	An Extractor (Source: SourceQuery, Key: param) whose Extract yields the parameter
//	value, or ErrNotFound when the parameter is absent or empty.
//
// Example:
//
//	extractor := FromQuery("token")
//	// URL: /api/data?token=abc123&format=json -> Output: "abc123"
//	// URL: /api/data?format=json -> Output: ErrNotFound
func FromQuery(param string) Extractor {
	readQuery := func(c fiber.Ctx) (value string, err error) {
		value = c.Query(param)
		if value == "" {
			err = ErrNotFound
		}
		return value, err
	}
	return Extractor{
		Extract: readQuery,
		Key:     param,
		Source:  SourceQuery,
	}
}
// FromCustom creates an Extractor using a provided function.
// This allows for custom extraction logic beyond the built-in extractors.
//
// The function:
//   - Accepts a custom extraction function with signature func(fiber.Ctx) (string, error)
//   - Handles nil functions gracefully by substituting one that returns ErrNotFound
//   - Otherwise uses fn unchanged as the Extract function (no wrapping)
//
// Because fn is used unchanged, a ("", nil) result is passed through as-is. Chain
// treats such a result as a miss and moves on, but code calling Extract directly
// sees a nil error with an empty value. Custom functions should therefore return
// ("", ErrNotFound) or another error when no value is available.
//
// Parameters:
//   - key: A descriptive identifier for the custom extractor.
//     Used for debugging, logging, and Chain metadata. Should be meaningful for introspection.
//     Examples: "X-Custom-Header", "Database-Lookup", "Cache-Key"
//   - fn: The custom function to extract the value from the fiber.Ctx.
//     If nil, the extractor will return ErrNotFound when executed.
//     The function should return (value, nil) on success or ("", error) on failure.
//
// Returns:
//
//	An Extractor (Source: SourceCustom, Key: key) that uses the provided function for
//	extraction. If fn is nil, the returned extractor always returns ErrNotFound.
//	Middleware cannot tell which part of the request a custom extractor reads.
//
// Examples:
//
//	// Custom header with transformation
//	extractor := FromCustom("X-API-Key", func(c fiber.Ctx) (string, error) {
//		value := c.Get("X-API-Key")
//		if value == "" {
//			return "", ErrNotFound
//		}
//		return strings.ToUpper(value), nil
//	})
//
//	// Database lookup (pseudo-code)
//	userExtractor := FromCustom("user-from-db", func(c fiber.Ctx) (string, error) {
//		userID := c.Params("userId")
//		user, err := db.GetUser(userID)
//		if err != nil {
//			return "", err
//		}
//		return user.Name, nil
//	})
//
//	// Conditional extraction: report ErrNotFound rather than ("", nil) when
//	// neither source provides a value.
//	smartExtractor := FromCustom("smart-auth", func(c fiber.Ctx) (string, error) {
//		if v := c.Get("X-Service-Auth"); v != "" {
//			return v, nil
//		}
//		if v := c.Cookies("session"); v != "" {
//			return v, nil
//		}
//		return "", ErrNotFound
//	})
func FromCustom(key string, fn func(fiber.Ctx) (string, error)) Extractor {
	extract := fn
	if extract == nil {
		// Keep Extract non-nil so callers can invoke it without a nil check.
		extract = func(fiber.Ctx) (string, error) { return "", ErrNotFound }
	}
	return Extractor{
		Extract: extract,
		Key:     key,
		Source:  SourceCustom,
	}
}
// Chain creates an Extractor that tries multiple extractors in order until one succeeds.
// This implements a fallback pattern where multiple extraction sources are attempted in sequence.
//
// The function:
//   - Tries each extractor in the order provided
//   - Returns the first successful extraction (non-empty value with no error)
//   - Skips extractors with nil Extract functions
//   - Returns the last non-nil error encountered if all extractors fail
//   - Returns ErrNotFound if no extractors are provided or none of them returned an error
//
// The extractors are snapshotted when Chain is called. Modifying the caller's slice
// afterwards (e.g. after Chain(list...)), or the returned Chain field, does not change
// which extractors run.
//
// Parameters:
//   - extractors: A variadic list of Extractor instances to try in sequence.
//     The order matters - more secure/preferred sources should be listed first.
//
// Returns:
//
//	An Extractor that attempts each provided extractor in order.
//	The returned extractor uses the Source and Key from the first extractor for metadata,
//	and exposes a copy of all extractors in its Chain field for introspection.
//
// Behavior:
//   - Success: Returns the first non-empty value with no error
//   - Partial failure: Continues to next extractor if current returns error or empty value
//   - Total failure: Returns last error encountered, or ErrNotFound if no errors
//   - Empty chain: Always returns ErrNotFound (Source: SourceCustom, empty Key)
//
// Examples:
//
//	// Try header first, then cookie, then query param
//	extractor := Chain(
//		FromHeader("Authorization"),
//		FromCookie("auth_token"),
//		FromQuery("token"),
//	)
//
//	// API key from multiple possible sources
//	apiKeyExtractor := Chain(
//		FromHeader("X-API-Key"),
//		FromQuery("api_key"),
//		FromForm("apiKey"),
//	)
//
// Security Note:
//
//	Order extractors by security preference. Most secure sources (headers, cookies)
//	should be attempted before less secure ones (query params, form data).
func Chain(extractors ...Extractor) Extractor {
	if len(extractors) == 0 {
		return Extractor{
			Extract: func(fiber.Ctx) (string, error) {
				return "", ErrNotFound
			},
			Source: SourceCustom,
			Key:    "",
			Chain:  []Extractor{},
		}
	}

	// Private snapshot used at request time. The variadic slice may alias the
	// caller's backing array (Chain(list...)), and the exported Chain field can be
	// edited by anyone holding the Extractor, so neither may drive execution.
	steps := append([]Extractor(nil), extractors...)

	return Extractor{
		Extract: func(c fiber.Ctx) (string, error) {
			var lastErr error // last error encountered (including ErrNotFound)
			for i := range steps {
				extract := steps[i].Extract
				if extract == nil {
					continue
				}
				v, err := extract(c)
				if err == nil && v != "" {
					return v, nil
				}
				if err != nil {
					lastErr = err
				}
			}
			if lastErr != nil {
				return "", lastErr
			}
			return "", ErrNotFound
		},
		// Use the source and key from the first extractor as the primary metadata.
		Source: steps[0].Source,
		Key:    steps[0].Key,
		Chain:  append([]Extractor(nil), extractors...), // Separate copy for introspection
	}
}
// isValidToken68 reports whether token matches the token68 grammar of
// RFC 9110 Section 11.2 (formerly RFC 7235):
//
//	token68 = 1*( ALPHA / DIGIT / "-" / "." / "_" / "~" / "+" / "/" ) *"="
//
// That is, at least one allowed character, optionally followed by '=' padding
// and nothing else.
func isValidToken68(token string) bool {
	// Peel off the trailing '=' padding; what remains must be non-empty.
	bodyLen := len(token)
	for bodyLen > 0 && token[bodyLen-1] == '=' {
		bodyLen--
	}
	if bodyLen == 0 {
		return false // empty input, or padding with nothing in front of it
	}
	for _, ch := range []byte(token[:bodyLen]) {
		if ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || ('0' <= ch && ch <= '9') {
			continue
		}
		switch ch {
		case '-', '.', '_', '~', '+', '/':
			// allowed punctuation
		default:
			return false // whitespace, '=' before the padding, or any other byte
		}
	}
	return true
}
================================================
FILE: extractors/extractors_test.go
================================================
package extractors
import (
"context"
"net/http"
"strings"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// go test -run Test_Extractors_Missing
// Test_Extractors_Missing verifies that every built-in extractor reports
// ErrNotFound (with an empty value) when its source is absent.
func Test_Extractors_Missing(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Record what FromParam sees on a route that has no :token segment. The
	// assertions run on the test goroutine below: require.* calls t.FailNow,
	// which must not be invoked from the goroutine serving the request.
	var (
		paramToken string
		paramErr   error
		called     bool
	)
	app.Get("/test", func(c fiber.Ctx) error {
		called = true
		paramToken, paramErr = FromParam("token").Extract(c)
		return nil
	})
	resp, err := app.Test(newRequest(fiber.MethodGet, "/test"))
	require.NoError(t, err)
	require.NoError(t, resp.Body.Close())
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.True(t, called, "route handler was not invoked")
	require.Empty(t, paramToken)
	require.ErrorIs(t, paramErr, ErrNotFound)

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() { app.ReleaseCtx(ctx) })

	// Missing form
	token, err := FromForm("token").Extract(ctx)
	require.Empty(t, token)
	require.ErrorIs(t, err, ErrNotFound)

	// Missing query
	token, err = FromQuery("token").Extract(ctx)
	require.Empty(t, token)
	require.ErrorIs(t, err, ErrNotFound)

	// Missing header
	token, err = FromHeader("X-Token").Extract(ctx)
	require.Empty(t, token)
	require.ErrorIs(t, err, ErrNotFound)

	// Missing Auth header
	token, err = FromAuthHeader("Bearer").Extract(ctx)
	require.Empty(t, token)
	require.ErrorIs(t, err, ErrNotFound)

	// Missing cookie
	token, err = FromCookie("token").Extract(ctx)
	require.Empty(t, token)
	require.ErrorIs(t, err, ErrNotFound)
}
// newRequest builds a body-less *http.Request suitable for Fiber's app.Test.
// It panics on invalid input because it is only meant for fixed test fixtures.
func newRequest(method, target string) *http.Request {
	background := context.Background()
	request, buildErr := http.NewRequestWithContext(background, method, target, http.NoBody)
	if buildErr == nil {
		return request
	}
	panic(buildErr)
}
// go test -run Test_Extractors
// Test_Extractors checks the happy path of every built-in extractor.
func Test_Extractors(t *testing.T) {
	t.Parallel()

	// extractTokenParam serves target through a "/test/:token" route and returns what
	// FromParam("token") produced inside the handler. Results are captured and
	// asserted on the test goroutine, because require.* (t.FailNow) must not run on
	// the goroutine serving the request.
	extractTokenParam := func(t *testing.T, target string) (string, error) {
		t.Helper()
		app := fiber.New()
		var (
			token      string
			extractErr error
			called     bool
		)
		app.Get("/test/:token", func(c fiber.Ctx) error {
			called = true
			token, extractErr = FromParam("token").Extract(c)
			return nil
		})
		resp, err := app.Test(newRequest(fiber.MethodGet, target))
		require.NoError(t, err)
		require.NoError(t, resp.Body.Close())
		require.True(t, called, "route handler was not invoked")
		return token, extractErr
	}

	t.Run("FromParam", func(t *testing.T) {
		t.Parallel()
		token, err := extractTokenParam(t, "/test/token_from_param")
		require.NoError(t, err)
		require.Equal(t, "token_from_param", token)
	})

	t.Run("FromParam_percent_encoded", func(t *testing.T) {
		t.Parallel()
		// With the default config, params are raw and must be decoded exactly once.
		token, err := extractTokenParam(t, "/test/token%20from%20param")
		require.NoError(t, err)
		require.Equal(t, "token from param", token)
	})

	t.Run("FromForm", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().Header.SetContentType(fiber.MIMEApplicationForm)
		ctx.Request().Header.SetMethod(fiber.MethodPost)
		ctx.Request().SetBodyString("token=token_from_form")
		token, err := FromForm("token").Extract(ctx)
		require.NoError(t, err)
		require.Equal(t, "token_from_form", token)
	})

	t.Run("FromQuery", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().SetRequestURI("/?token=token_from_query")
		token, err := FromQuery("token").Extract(ctx)
		require.NoError(t, err)
		require.Equal(t, "token_from_query", token)
	})

	t.Run("FromHeader", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().Header.Set("X-Token", "token_from_header")
		token, err := FromHeader("X-Token").Extract(ctx)
		require.NoError(t, err)
		require.Equal(t, "token_from_header", token)
	})

	t.Run("FromAuthHeader", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().Header.Set(fiber.HeaderAuthorization, "Bearer token_from_auth_header")
		token, err := FromAuthHeader("Bearer").Extract(ctx)
		require.NoError(t, err)
		require.Equal(t, "token_from_auth_header", token)
	})

	t.Run("FromCookie", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().Header.SetCookie("token", "token_from_cookie")
		token, err := FromCookie("token").Extract(ctx)
		require.NoError(t, err)
		require.Equal(t, "token_from_cookie", token)
	})
}
// go test -run Test_Extractor_Chain
// Test_Extractor_Chain covers Chain's fallback semantics: empty chains, first and
// later successes, total failure, empty-value results, and nil Extract links.
func Test_Extractor_Chain(t *testing.T) {
	t.Parallel()

	t.Run("no_extractors", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		token, err := Chain().Extract(ctx)
		require.Empty(t, token)
		require.ErrorIs(t, err, ErrNotFound)
	})

	t.Run("first_extractor_succeeds", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().Header.Set("X-Token", "token_from_header")
		ctx.Request().SetRequestURI("/?token=token_from_query")
		token, err := Chain(FromHeader("X-Token"), FromQuery("token")).Extract(ctx)
		require.NoError(t, err)
		require.Equal(t, "token_from_header", token)
	})

	t.Run("second_extractor_succeeds", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().SetRequestURI("/?token=token_from_query")
		token, err := Chain(FromHeader("X-Token"), FromQuery("token")).Extract(ctx)
		require.NoError(t, err)
		require.Equal(t, "token_from_query", token)
	})

	t.Run("all_extractors_fail", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		token, err := Chain(FromHeader("X-Token"), FromQuery("token")).Extract(ctx)
		require.Empty(t, token)
		require.ErrorIs(t, err, ErrNotFound)
	})

	t.Run("empty_extractor_returns_not_found", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		// This extractor will return "", nil
		dummyExtractor := Extractor{
			Extract: func(_ fiber.Ctx) (string, error) {
				return "", nil
			},
			Source: SourceCustom,
			Key:    "token",
		}
		token, err := Chain(dummyExtractor).Extract(ctx)
		require.Empty(t, token)
		require.ErrorIs(t, err, ErrNotFound)
	})

	t.Run("nil_extract_is_skipped", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().SetRequestURI("/?token=token_from_query")
		// A link without an Extract function must be skipped, not invoked.
		nilLink := Extractor{Key: "nil", Source: SourceCustom}
		token, err := Chain(nilLink, FromQuery("token")).Extract(ctx)
		require.NoError(t, err)
		require.Equal(t, "token_from_query", token)

		// A chain made only of nil links behaves like a chain with no errors.
		token, err = Chain(nilLink).Extract(ctx)
		require.Empty(t, token)
		require.ErrorIs(t, err, ErrNotFound)
	})
}
// go test -run Test_Extractor_FromAuthHeader_EdgeCases
func Test_Extractor_FromAuthHeader_EdgeCases(t *testing.T) {
t.Parallel()
t.Run("wrong_scheme", func(t *testing.T) {
t.Parallel()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
ctx.Request().Header.Set(fiber.HeaderAuthorization, "Basic dXNlcjpwYXNz") // Basic auth instead of Bearer
token, err := FromAuthHeader("Bearer").Extract(ctx)
require.Empty(t, token)
require.ErrorIs(t, err, ErrNotFound)
})
t.Run("missing_space_after_scheme", func(t *testing.T) {
t.Parallel()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
ctx.Request().Header.Set(fiber.HeaderAuthorization, "Bearertoken") // Missing space after Bearer
token, err := FromAuthHeader("Bearer").Extract(ctx)
require.Empty(t, token)
require.ErrorIs(t, err, ErrNotFound)
})
t.Run("case_insensitive_scheme_matching", func(t *testing.T) {
t.Parallel()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
ctx.Request().Header.Set(fiber.HeaderAuthorization, "bearer token") // lowercase bearer
token, err := FromAuthHeader("Bearer").Extract(ctx)
require.NoError(t, err)
require.Equal(t, "token", token)
})
}
// go test -run Test_Extractor_Chain_Introspection
func Test_Extractor_Chain_Introspection(t *testing.T) {
t.Parallel()
// Test chain introspection
extractor1 := FromHeader("X-Token")
extractor2 := FromQuery("token")
extractor3 := FromCookie("auth")
chainExtractor := Chain(extractor1, extractor2, extractor3)
// Verify chain metadata
require.Equal(t, SourceHeader, chainExtractor.Source)
require.Equal(t, "X-Token", chainExtractor.Key)
require.Len(t, chainExtractor.Chain, 3)
// Verify individual extractors in chain
require.Equal(t, SourceHeader, chainExtractor.Chain[0].Source)
require.Equal(t, "X-Token", chainExtractor.Chain[0].Key)
require.Equal(t, SourceQuery, chainExtractor.Chain[1].Source)
require.Equal(t, "token", chainExtractor.Chain[1].Key)
require.Equal(t, SourceCookie, chainExtractor.Chain[2].Source)
require.Equal(t, "auth", chainExtractor.Chain[2].Key)
}
// go test -run Test_Extractor_FromCustom
// Test_Extractor_FromCustom covers custom extractors: successful extraction and
// metadata, error pass-through (including the *fiber.Error code), the ("", nil)
// pass-through, and the nil-function fallback.
func Test_Extractor_FromCustom(t *testing.T) {
	t.Parallel()

	t.Run("successful_extraction", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().Header.Set("X-Custom", "custom-value")
		customExtractor := FromCustom("X-Custom", func(c fiber.Ctx) (string, error) {
			value := c.Get("X-Custom")
			if value == "" {
				return "", ErrNotFound
			}
			return strings.ToUpper(value), nil
		})
		token, err := customExtractor.Extract(ctx)
		require.NoError(t, err)
		require.Equal(t, "CUSTOM-VALUE", token)
		// Verify metadata
		require.Equal(t, SourceCustom, customExtractor.Source)
		require.Equal(t, "X-Custom", customExtractor.Key)
		require.Empty(t, customExtractor.AuthScheme)
	})

	t.Run("extraction_with_error", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		errorExtractor := FromCustom("test", func(_ fiber.Ctx) (string, error) {
			return "", fiber.NewError(fiber.StatusBadRequest, "Custom error")
		})
		token, err := errorExtractor.Extract(ctx)
		require.Empty(t, token)
		// The custom error must be passed through unchanged, including its status code.
		var fe *fiber.Error
		require.ErrorAs(t, err, &fe)
		require.Equal(t, fiber.StatusBadRequest, fe.Code)
		require.Contains(t, err.Error(), "Custom error")
	})

	t.Run("extraction_returning_empty_string", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		emptyExtractor := FromCustom("empty", func(_ fiber.Ctx) (string, error) {
			return "", nil
		})
		token, err := emptyExtractor.Extract(ctx)
		require.Empty(t, token)
		require.NoError(t, err) // Should return empty string with no error
	})

	t.Run("nil_function", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		nilExtractor := FromCustom("nil", nil)
		// The nil function is replaced, and the metadata is still populated.
		require.NotNil(t, nilExtractor.Extract)
		require.Equal(t, SourceCustom, nilExtractor.Source)
		require.Equal(t, "nil", nilExtractor.Key)
		token, err := nilExtractor.Extract(ctx)
		require.Empty(t, token)
		require.ErrorIs(t, err, ErrNotFound) // Should return ErrNotFound for nil function
	})
}
// go test -run Test_Extractor_Chain_Error_Propagation
// Test_Extractor_Chain_Error_Propagation verifies that when every link fails,
// Chain surfaces the error from the last link, not the first.
func Test_Extractor_Chain_Error_Propagation(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// failWith builds a custom link that always fails with a *fiber.Error.
	failWith := func(key string, code int, message string) Extractor {
		return Extractor{
			Extract: func(_ fiber.Ctx) (string, error) {
				return "", fiber.NewError(code, message)
			},
			Key:    key,
			Source: SourceCustom,
		}
	}
	chained := Chain(
		failWith("error1", fiber.StatusBadRequest, "First error"),
		failWith("error2", fiber.StatusUnauthorized, "Second error"),
	)

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() { app.ReleaseCtx(ctx) })

	got, err := chained.Extract(ctx)
	require.Empty(t, got)
	require.Error(t, err)
	require.Contains(t, err.Error(), "Second error") // the last error wins

	var fiberErr *fiber.Error
	require.ErrorAs(t, err, &fiberErr)
	require.Equal(t, fiber.StatusUnauthorized, fiberErr.Code)
}
// go test -run Test_Extractor_Chain_With_Success
// Test_Extractor_Chain_With_Success verifies that a failing first link falls
// through to a succeeding second link.
func Test_Extractor_Chain_With_Success(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() { app.ReleaseCtx(ctx) })

	chained := Chain(
		Extractor{
			Extract: func(_ fiber.Ctx) (string, error) { return "", ErrNotFound },
			Key:     "fail",
			Source:  SourceCustom,
		},
		Extractor{
			Extract: func(_ fiber.Ctx) (string, error) { return "success-token", nil },
			Key:     "success",
			Source:  SourceCustom,
		},
	)

	got, err := chained.Extract(ctx)
	require.NoError(t, err)
	require.Equal(t, "success-token", got)
}
// go test -run Test_Extractor_FromAuthHeader_CustomScheme
func Test_Extractor_FromAuthHeader_CustomScheme(t *testing.T) {
t.Parallel()
app := fiber.New()
// Test with custom auth scheme
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
ctx.Request().Header.Set(fiber.HeaderAuthorization, "CustomScheme my-token")
extractor := FromAuthHeader("CustomScheme")
token, err := extractor.Extract(ctx)
require.NoError(t, err)
require.Equal(t, "my-token", token)
// Verify metadata
require.Equal(t, SourceAuthHeader, extractor.Source)
require.Equal(t, fiber.HeaderAuthorization, extractor.Key)
require.Equal(t, "CustomScheme", extractor.AuthScheme)
}
// go test -run Test_Extractor_FromAuthHeader_WhitespaceToken
func Test_Extractor_FromAuthHeader_WhitespaceToken(t *testing.T) {
t.Parallel()
app := fiber.New()
// Test with token containing whitespace (should be rejected per RFC 7235 token68 spec)
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
ctx.Request().Header.Set(fiber.HeaderAuthorization, "Bearer token with spaces and\ttabs")
extractor := FromAuthHeader("Bearer")
token, err := extractor.Extract(ctx)
require.Error(t, err)
require.ErrorIs(t, err, ErrNotFound)
require.Empty(t, token)
// Verify metadata
require.Equal(t, SourceAuthHeader, extractor.Source)
require.Equal(t, fiber.HeaderAuthorization, extractor.Key)
require.Equal(t, "Bearer", extractor.AuthScheme)
}
// go test -run Test_Extractor_FromAuthHeader_RFC_Compliance
// Test_Extractor_FromAuthHeader_RFC_Compliance pins the separator rules of
// FromAuthHeader: exactly one SP between scheme and token, no tabs, and
// case-insensitive scheme matching.
func Test_Extractor_FromAuthHeader_RFC_Compliance(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		name          string
		header        string
		expectedToken string
		description   string
		shouldFail    bool
	}{
		{
			name:        "tab_after_scheme",
			header:      "Bearer\ttoken",
			shouldFail:  true,
			description: "tab character after scheme should be rejected - RFC specifies 1*SP, not tabs",
		},
		{
			name:          "single_space_after_scheme",
			header:        "Bearer token",
			shouldFail:    false,
			expectedToken: "token",
			description:   "single space after scheme should be accepted - standard format",
		},
		{
			name: "multiple_spaces_after_scheme",
			// Two spaces between scheme and token; with one space this case would
			// duplicate single_space_after_scheme and contradict shouldFail.
			header:      "Bearer  token",
			shouldFail:  true,
			description: "multiple spaces after scheme rejected for simplicity - single space is standard",
		},
		{
			name:        "mixed_whitespace_after_scheme",
			header:      "Bearer \t \ttoken",
			shouldFail:  true,
			description: "mixed whitespace after scheme should be rejected - RFC specifies 1*SP, not tabs",
		},
		{
			name:        "no_whitespace_after_scheme",
			header:      "Bearertoken",
			shouldFail:  true,
			description: "no whitespace after scheme should fail",
		},
		{
			name:        "header_too_short",
			header:      "Bearer",
			shouldFail:  true,
			description: "header too short for scheme + space + token",
		},
		{
			name:        "only_whitespace_after_scheme",
			header:      "Bearer \t ",
			shouldFail:  true,
			description: "only whitespace after scheme should fail",
		},
		{
			name:          "case_insensitive_scheme",
			header:        "BEARER token",
			shouldFail:    false,
			expectedToken: "token",
			description:   "case-insensitive scheme matching should work",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
			t.Cleanup(func() { app.ReleaseCtx(ctx) })
			ctx.Request().Header.Set(fiber.HeaderAuthorization, tc.header)
			token, err := FromAuthHeader("Bearer").Extract(ctx)
			if tc.shouldFail {
				require.Error(t, err, "Expected error for %s", tc.description)
				require.ErrorIs(t, err, ErrNotFound)
				require.Empty(t, token)
			} else {
				require.NoError(t, err, "Expected no error for %s", tc.description)
				require.Equal(t, tc.expectedToken, token)
			}
		})
	}

	// Special case for case-insensitive scheme matching with different extractor scheme
	t.Run("case_insensitive_extractor_scheme", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(ctx) })
		ctx.Request().Header.Set(fiber.HeaderAuthorization, "BEARER token")
		token, err := FromAuthHeader("bearer").Extract(ctx) // lowercase extractor scheme
		require.NoError(t, err)
		require.Equal(t, "token", token)
	})
}
// go test -run Test_Extractor_FromAuthHeader_Token68_Validation
func Test_Extractor_FromAuthHeader_Token68_Validation(t *testing.T) {
t.Parallel()
app := fiber.New()
// Test valid token68 characters (should pass)
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
ctx.Request().Header.Set(fiber.HeaderAuthorization, "Bearer ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~+/=")
token, err := FromAuthHeader("Bearer").Extract(ctx)
require.NoError(t, err)
require.Equal(t, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~+/=", token)
// Test tokens with spaces (should fail)
testCases := []struct {
name string
header string
description string
shouldFail bool
}{
{name: "space_in_token", header: "Bearer abc def", shouldFail: true, description: "space in token"},
{name: "space_after_scheme", header: "Bearer abc", shouldFail: true, description: "multiple spaces after scheme"},
{name: "no_space_after_scheme", header: "Bearertoken", shouldFail: true, description: "no space after scheme"},
{name: "only_scheme", header: "Bearer", shouldFail: true, description: "only scheme, no token"},
{name: "tab_after_scheme", header: "Bearer\ttoken", shouldFail: true, description: "tab after scheme"},
{name: "tab_in_token", header: "Bearer abc\tdef", shouldFail: true, description: "tab in token"},
{name: "newline_in_token", header: "Bearer abc\ndef", shouldFail: true, description: "newline in token"},
{name: "leading_space_in_token", header: "Bearer abc", shouldFail: true, description: "leading space in token after scheme space"},
{name: "trailing_space_in_token", header: "Bearer abc ", shouldFail: true, description: "trailing space in token"},
{name: "comma_in_token", header: "Bearer abc,def", shouldFail: true, description: "comma in token"},
{name: "semicolon_in_token", header: "Bearer abc;def", shouldFail: true, description: "semicolon in token"},
{name: "quote_in_token", header: "Bearer abc\"def", shouldFail: true, description: "quote in token"},
{name: "bracket_in_token", header: "Bearer abc[def", shouldFail: true, description: "bracket in token"},
{name: "equals_at_start", header: "Bearer =abc", shouldFail: true, description: "equals at start of token"},
{name: "equals_in_middle", header: "Bearer ab=cd", shouldFail: true, description: "equals in middle of token"},
{name: "valid_equals_at_end", header: "Bearer abc=", shouldFail: false, description: "valid equals at end"},
{name: "valid_double_equals", header: "Bearer abc==", shouldFail: false, description: "valid double equals at end"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
ctx.Request().Header.Set(fiber.HeaderAuthorization, tc.header)
token, err := FromAuthHeader("Bearer").Extract(ctx)
if tc.shouldFail {
require.Error(t, err, "Expected error for %s", tc.description)
require.ErrorIs(t, err, ErrNotFound)
require.Empty(t, token)
} else {
require.NoError(t, err, "Expected no error for %s", tc.description)
require.NotEmpty(t, token)
}
})
}
}
// go test -run Test_Extractor_FromAuthHeader_NoScheme
//
// Test_Extractor_FromAuthHeader_NoScheme covers FromAuthHeader with an empty
// scheme: the raw Authorization value is returned unchanged, and a request
// without the header yields ErrNotFound.
func Test_Extractor_FromAuthHeader_NoScheme(t *testing.T) {
	t.Parallel()

	t.Run("returns_header_value", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })

		c.Request().Header.Set(fiber.HeaderAuthorization, "some-token-value")

		ex := FromAuthHeader("") // empty scheme: no prefix is stripped
		got, err := ex.Extract(c)
		require.NoError(t, err)
		require.Equal(t, "some-token-value", got)

		// Metadata still describes the Authorization header as the source.
		require.Equal(t, SourceAuthHeader, ex.Source)
		require.Equal(t, fiber.HeaderAuthorization, ex.Key)
		require.Empty(t, ex.AuthScheme)
	})

	t.Run("empty_header_returns_not_found", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })

		// The request deliberately carries no Authorization header.
		got, err := FromAuthHeader("").Extract(c)
		require.Empty(t, got)
		require.ErrorIs(t, err, ErrNotFound)
	})
}
// go test -run Test_Extractor_Chain_NilFunctions
//
// Test_Extractor_Chain_NilFunctions verifies that a chain skips an extractor
// whose Extract func is nil and continues with the next one.
func Test_Extractor_Chain_NilFunctions(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	chain := Chain(
		Extractor{Key: "nil", Source: SourceCustom, Extract: nil},
		Extractor{
			Key:    "valid",
			Source: SourceCustom,
			Extract: func(_ fiber.Ctx) (string, error) {
				return "valid-token", nil
			},
		},
	)

	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() { app.ReleaseCtx(c) })

	got, err := chain.Extract(c)
	require.NoError(t, err)
	require.Equal(t, "valid-token", got)
}
// go test -run Test_Extractor_Chain_AllErrors
//
// Test_Extractor_Chain_AllErrors verifies that when every extractor in a chain
// fails, the chain returns an empty token and the error of the last extractor.
func Test_Extractor_Chain_AllErrors(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// failing builds an extractor that always returns a *fiber.Error.
	failing := func(key string, code int, msg string) Extractor {
		return Extractor{
			Extract: func(_ fiber.Ctx) (string, error) {
				return "", fiber.NewError(code, msg)
			},
			Key:    key,
			Source: SourceCustom,
		}
	}
	chain := Chain(
		failing("error1", fiber.StatusUnauthorized, "First auth error"),
		failing("error2", fiber.StatusForbidden, "Second auth error"),
	)

	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() { app.ReleaseCtx(c) })

	got, err := chain.Extract(c)
	require.Empty(t, got)
	require.Error(t, err)
	require.Contains(t, err.Error(), "Second auth error") // the last error wins

	var fiberErr *fiber.Error
	require.ErrorAs(t, err, &fiberErr)
	require.Equal(t, fiber.StatusForbidden, fiberErr.Code)
}
// go test -run Test_Extractor_Chain_MixedScenarios
//
// Test_Extractor_Chain_MixedScenarios combines not-found, error, and success
// extractors: a later success overrides earlier failures, and without one the
// last error is reported.
func Test_Extractor_Chain_MixedScenarios(t *testing.T) {
	t.Parallel()

	failing := Extractor{
		Extract: func(_ fiber.Ctx) (string, error) {
			return "", ErrNotFound
		},
		Key:    "fail",
		Source: SourceCustom,
	}
	erroring := Extractor{
		Extract: func(_ fiber.Ctx) (string, error) {
			return "", fiber.NewError(fiber.StatusBadRequest, "Bad request")
		},
		Key:    "error",
		Source: SourceCustom,
	}
	succeeding := Extractor{
		Extract: func(_ fiber.Ctx) (string, error) {
			return "success", nil
		},
		Key:    "success",
		Source: SourceCustom,
	}

	scenarios := []struct {
		name      string
		wantToken string
		wantErr   string // non-empty: the chain must fail with this message fragment
		chain     []Extractor
	}{
		{name: "error_then_success", wantToken: "success", chain: []Extractor{erroring, succeeding}},
		{name: "fail_then_error_then_success", wantToken: "success", chain: []Extractor{failing, erroring, succeeding}},
		{name: "fail_then_error", wantErr: "Bad request", chain: []Extractor{failing, erroring}},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			c := app.AcquireCtx(&fasthttp.RequestCtx{})
			t.Cleanup(func() { app.ReleaseCtx(c) })

			got, err := Chain(sc.chain...).Extract(c)
			if sc.wantErr != "" {
				require.Empty(t, got)
				require.Error(t, err)
				require.Contains(t, err.Error(), sc.wantErr)
				return
			}
			require.NoError(t, err)
			require.Equal(t, sc.wantToken, got)
		})
	}
}
// go test -run Test_Extractor_SourceTypes
//
// Test_Extractor_SourceTypes checks the Source metadata reported by each
// built-in extractor constructor and by Chain.
func Test_Extractor_SourceTypes(t *testing.T) {
	t.Parallel()

	t.Run("individual_extractor_sources", func(t *testing.T) {
		t.Parallel()
		custom := FromCustom("test", func(_ fiber.Ctx) (string, error) { return "test", nil })

		require.Equal(t, SourceHeader, FromHeader("test").Source)
		require.Equal(t, SourceAuthHeader, FromAuthHeader("Bearer").Source)
		// An empty scheme still reports the Authorization header as the source.
		require.Equal(t, SourceAuthHeader, FromAuthHeader("").Source)
		require.Equal(t, SourceForm, FromForm("test").Source)
		require.Equal(t, SourceQuery, FromQuery("test").Source)
		require.Equal(t, SourceParam, FromParam("test").Source)
		require.Equal(t, SourceCookie, FromCookie("test").Source)
		require.Equal(t, SourceCustom, custom.Source)
	})

	t.Run("chain_source_metadata", func(t *testing.T) {
		t.Parallel()
		chained := Chain(FromHeader("X-Test"), FromQuery("test"))

		// A chain reports the metadata of its first extractor.
		require.Equal(t, SourceHeader, chained.Source)
		require.Equal(t, "X-Test", chained.Key)
	})
}
// go test -run Test_Extractor_URL_Encoded
func Test_Extractor_URL_Encoded(t *testing.T) {
t.Parallel()
t.Run("FromQuery_with_spaces", func(t *testing.T) {
t.Parallel()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
ctx.Request().SetRequestURI("/?token=token%20with%20spaces")
token, err := FromQuery("token").Extract(ctx)
require.NoError(t, err)
require.Equal(t, "token with spaces", token) // Should be URL-decoded automatically by fasthttp
})
t.Run("FromForm_with_plus", func(t *testing.T) {
t.Parallel()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
ctx.Request().Header.SetContentType(fiber.MIMEApplicationForm)
ctx.Request().Header.SetMethod(fiber.MethodPost)
ctx.Request().SetBodyString("token=token%2Bwith%2Bplus")
token, err := FromForm("token").Extract(ctx)
require.NoError(t, err)
require.Equal(t, "token+with+plus", token) // URL-decoded
})
t.Run("FromQuery_base64_encoded", func(t *testing.T) {
t.Parallel()
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
t.Cleanup(func() { app.ReleaseCtx(ctx) })
base64Value := "cGFzc3dvcmQ%3D" // URL-encoded base64 "cGFzc3dvcmQ="
ctx.Request().SetRequestURI("/?token=" + base64Value)
token, err := FromQuery("token").Extract(ctx)
require.NoError(t, err)
require.Equal(t, "cGFzc3dvcmQ=", token) // Should be URL-decoded
})
t.Run("FromParam_with_slashes", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test/:token", func(c fiber.Ctx) error {
token, extractErr := FromParam("token").Extract(c)
require.NoError(t, extractErr)
require.Equal(t, "token/with/slashes", token)
return nil
})
_, err := app.Test(newRequest(fiber.MethodGet, "/test/token%2Fwith%2Fslashes"))
require.NoError(t, err)
})
}
// Test_isValidToken68 exercises isValidToken68 against the token68 grammar:
// letters, digits, and "-._~+/" followed by optional trailing '=' padding.
func Test_isValidToken68(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		input string
		valid bool
	}{
		{"empty string", "", false},
		{"single uppercase", "A", true},
		{"single lowercase", "a", true},
		{"single digit", "0", true},
		{"all allowed symbols except =", "-._~+/", true},
		{"letters and digits", "token68", true},
		{"equals at end", "token=", true},
		{"multiple equals", "token==", true},
		{"equals at start", "=token", false},
		{"equals in middle", "tok=en", false},
		{"equals not at end with other chars", "token=extra", false},
		{"space in token", "token space", false},
		{"tab character in token", "token\ttab", false},
		{"invalid symbol", "token@", false},
		{"valid token68", "token68", true},
		{"valid token68 with equals at end", "token68=", true},
		{"multiple equals at end", "token68==", true},
		{"equals followed by extra chars", "token68=extra", false},
		{"all allowed chars with equals at end", "T0ken-._~+/=", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			if got := isValidToken68(tt.input); got != tt.valid {
				t.Errorf("isValidToken68(%q) = %v, want %v", tt.input, got, tt.valid)
			}
		})
	}
}
================================================
FILE: go.mod
================================================
module github.com/gofiber/fiber/v3
go 1.25.0
require (
github.com/gofiber/schema v1.7.0
github.com/gofiber/utils/v2 v2.0.2
github.com/google/uuid v1.6.0
github.com/mattn/go-colorable v0.1.14
github.com/mattn/go-isatty v0.0.20
github.com/shamaton/msgpack/v3 v3.1.0
github.com/stretchr/testify v1.11.1
github.com/tinylib/msgp v1.6.3
github.com/valyala/bytebufferpool v1.0.0
github.com/valyala/fasthttp v1.69.0
golang.org/x/crypto v0.49.0
)
require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // direct
github.com/klauspost/compress v1.18.4 // indirect
github.com/philhofer/fwd v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/net v0.52.0
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0
gopkg.in/yaml.v3 v3.0.1 // indirect
)
================================================
FILE: go.sum
================================================
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gofiber/schema v1.7.0 h1:yNM+FNRZjyYEli9Ey0AXRBrAY9jTnb+kmGs3lJGPvKg=
github.com/gofiber/schema v1.7.0/go.mod h1:A/X5Ffyru4p9eBdp99qu+nzviHzQiZ7odLT+TwxWhbk=
github.com/gofiber/utils/v2 v2.0.2 h1:ShRRssz0F3AhTlAQcuEj54OEDtWF7+HJDwEi/aa6QLI=
github.com/gofiber/utils/v2 v2.0.2/go.mod h1:+9Ub4NqQ+IaJoTliq5LfdmOJAA/Hzwf4pXOxOa3RrJ0=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/shamaton/msgpack/v3 v3.1.0 h1:jsk0vEAqVvvS9+fTZ5/EcQ9tz860c9pWxJ4Iwecz8gU=
github.com/shamaton/msgpack/v3 v3.1.0/go.mod h1:DcQG8jrdrQCIxr3HlMYkiXdMhK+KfN2CitkyzsQV4uc=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s=
github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI=
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
================================================
FILE: group.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"fmt"
"reflect"
)
// Group represents a collection of routes that share middleware and a common
// path prefix.
//
// Groups are created by Group / Route; routes added through a group are
// registered on the owning App with the group's Prefix prepended.
type Group struct {
	app             *App   // owning application; all registrations go through app.register
	parentGroup     *Group // enclosing group, or nil when this group has no parent group
	name            string // accumulated group name (parent name + own name), set by Name
	Prefix          string // full path prefix prepended to every path registered via this group
	anyRouteDefined bool   // set once Use/Add registered something; switches Name from group naming to route naming
}
// Name assigns a name to the group itself or to a route of the group.
//
// If no route has been added to the group yet, the group name is set to the
// parent group's name (if any) followed by name, and the OnGroupName hooks are
// executed. Otherwise the name is forwarded to App.Name, which sets the route
// name and runs the OnName hooks.
//
// Name panics if an OnGroupName hook returns an error.
func (grp *Group) Name(name string) Router {
	if grp.anyRouteDefined {
		grp.app.Name(name)
		return grp
	}

	grp.app.mutex.Lock()
	// Release via defer: a hook error panics below, and an explicit Unlock after
	// it would be skipped, leaving the app mutex held forever. Any caller that
	// recovers would then deadlock on its next route registration.
	defer grp.app.mutex.Unlock()

	if grp.parentGroup != nil {
		grp.name = grp.parentGroup.name + name
	} else {
		grp.name = name
	}

	if err := grp.app.hooks.executeOnGroupNameHooks(*grp); err != nil {
		panic(err)
	}

	return grp
}
// Use registers middleware for every HTTP method under an optional prefix
// (relative to the group's prefix; defaults to the group root).
//
// Arguments are classified by type:
//   - string:   a single prefix
//   - []string: several prefixes; takes precedence over a single string prefix
//   - *App:     a sub-application to mount at the prefix
//   - anything else must be convertible to a Handler, otherwise Use panics.
//
// When a sub-application is given, it is mounted at the first prefix only and
// the mount result is returned immediately.
//
//	app.Use(func(c fiber.Ctx) error {
//	     return c.Next()
//	})
//	app.Use("/api", func(c fiber.Ctx) error {
//	     return c.Next()
//	})
//	app.Use("/api", handler, func(c fiber.Ctx) error {
//	     return c.Next()
//	})
//	subApp := fiber.New()
//	app.Use("/mounted-path", subApp)
//
// This method will match all HTTP verbs: GET, POST, PUT, HEAD etc...
func (grp *Group) Use(args ...any) Router {
	var (
		mounted  *App
		single   string
		multiple []string
		chain    []Handler
	)

	for _, raw := range args {
		switch v := raw.(type) {
		case string:
			single = v
		case *App:
			mounted = v
		case []string:
			multiple = v
		default:
			h, ok := toFiberHandler(v)
			if !ok {
				panic(fmt.Sprintf("use: invalid handler %v\n", reflect.TypeOf(v)))
			}
			chain = append(chain, h)
		}
	}

	if len(multiple) == 0 {
		multiple = append(multiple, single)
	}

	for _, p := range multiple {
		if mounted != nil {
			// Mounting returns right away, so only the first prefix is used.
			return grp.mount(p, mounted)
		}
		grp.app.register([]string{methodUse}, getGroupPath(grp.Prefix, p), grp, chain...)
	}

	grp.anyRouteDefined = true
	return grp
}
// Get registers a route for GET requests, which ask for a representation of
// the target resource and should only retrieve data.
func (grp *Group) Get(path string, handler any, handlers ...any) Router {
	methods := []string{MethodGet}
	return grp.Add(methods, path, handler, handlers...)
}

// Head registers a route for HEAD requests, which expect the same response as
// GET but without a body.
func (grp *Group) Head(path string, handler any, handlers ...any) Router {
	methods := []string{MethodHead}
	return grp.Add(methods, path, handler, handlers...)
}

// Post registers a route for POST requests, which submit an entity to the
// target resource and commonly change server state.
func (grp *Group) Post(path string, handler any, handlers ...any) Router {
	methods := []string{MethodPost}
	return grp.Add(methods, path, handler, handlers...)
}

// Put registers a route for PUT requests, which replace every current
// representation of the target resource with the request payload.
func (grp *Group) Put(path string, handler any, handlers ...any) Router {
	methods := []string{MethodPut}
	return grp.Add(methods, path, handler, handlers...)
}

// Delete registers a route for DELETE requests, which remove the target resource.
func (grp *Group) Delete(path string, handler any, handlers ...any) Router {
	methods := []string{MethodDelete}
	return grp.Add(methods, path, handler, handlers...)
}

// Connect registers a route for CONNECT requests, which open a tunnel to the
// server identified by the target resource.
func (grp *Group) Connect(path string, handler any, handlers ...any) Router {
	methods := []string{MethodConnect}
	return grp.Add(methods, path, handler, handlers...)
}

// Options registers a route for OPTIONS requests, which describe the
// communication options available for the target resource.
func (grp *Group) Options(path string, handler any, handlers ...any) Router {
	methods := []string{MethodOptions}
	return grp.Add(methods, path, handler, handlers...)
}

// Trace registers a route for TRACE requests, which perform a message
// loop-back test along the path to the target resource.
func (grp *Group) Trace(path string, handler any, handlers ...any) Router {
	methods := []string{MethodTrace}
	return grp.Add(methods, path, handler, handlers...)
}

// Patch registers a route for PATCH requests, which apply partial
// modifications to a resource.
func (grp *Group) Patch(path string, handler any, handlers ...any) Router {
	methods := []string{MethodPatch}
	return grp.Add(methods, path, handler, handlers...)
}
// Add registers one route for each of the given HTTP methods under path
// (relative to the group's prefix). Handlers run in order: handler first,
// then each of handlers. After the call the group counts as having routes,
// so a following Name call names a route rather than the group.
func (grp *Group) Add(methods []string, path string, handler any, handlers ...any) Router {
	all := make([]any, 0, len(handlers)+1)
	all = append(all, handler)
	all = append(all, handlers...)

	chain := collectHandlers("group", all...)
	grp.app.register(methods, getGroupPath(grp.Prefix, path), grp, chain...)
	grp.anyRouteDefined = true

	return grp
}
// All registers the handlers for every method listed in the app's
// RequestMethods configuration.
func (grp *Group) All(path string, handler any, handlers ...any) Router {
	return grp.Add(grp.app.config.RequestMethods, path, handler, handlers...)
}
// Group creates a nested sub-router whose prefix is appended to this group's
// prefix. Any handlers given are registered as middleware on the new prefix.
// The OnGroup hooks run for the new group; a hook error causes a panic.
//
//	api := app.Group("/api")
//	api.Get("/users", handler)
func (grp *Group) Group(prefix string, handlers ...any) Router {
	fullPrefix := getGroupPath(grp.Prefix, prefix)

	if len(handlers) != 0 {
		middleware := collectHandlers("group", handlers...)
		grp.app.register([]string{methodUse}, fullPrefix, grp, middleware...)
	}

	child := &Group{
		app:         grp.app,
		parentGroup: grp,
		Prefix:      fullPrefix,
	}
	if err := grp.app.hooks.executeOnGroupHooks(*child); err != nil {
		panic(err)
	}

	return child
}
// RouteChain returns a Register bound to path (joined with the group's prefix)
// so several methods can be declared for that path in a chained call.
func (grp *Group) RouteChain(path string) Register {
	return &Registering{
		app:   grp.app,
		group: grp,
		path:  getGroupPath(grp.Prefix, path),
	}
}
// Route creates a sub-group under prefix, optionally names it with name[0],
// and passes it to fn so routes can be declared inside a closure. It panics
// when fn is nil and returns the created sub-group.
func (grp *Group) Route(prefix string, fn func(router Router), name ...string) Router {
	if fn == nil {
		panic("route handler 'fn' cannot be nil")
	}

	sub := grp.Group(prefix)
	if len(name) != 0 {
		sub.Name(name[0])
	}

	fn(sub)
	return sub
}
================================================
FILE: helpers.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"reflect"
"slices"
"strconv"
"strings"
"sync"
"time"
"unsafe"
"github.com/gofiber/utils/v2"
utilsbytes "github.com/gofiber/utils/v2/bytes"
"github.com/gofiber/fiber/v3/log"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)
// acceptedType is a struct that holds the parsed value of an Accept header
// along with quality, specificity, parameters, and order.
// Used for sorting accept headers.
type acceptedType struct {
	params      headerParams // parameters attached to the entry (NOTE(review): whether q is excluded — confirm in the parser)
	spec        string       // the media range / value itself, e.g. "text/html"
	quality     float64      // the entry's q-value weight
	specificity int          // presumably ranks how specific spec is (wildcards lower) — confirm in the parser
	order       int          // position within the header; presumably the final sort tie-breaker
}
// noCacheValue is the literal "no-cache" token (presumably a Cache-Control /
// Pragma directive value — confirm at call sites).
const noCacheValue = "no-cache"
// Pre-allocated byte slices for accept header parsing
var (
	semicolonQEquals = []byte(";q=") // separator that introduces a quality parameter
	wildcardAll      = []byte("*/*") // the match-everything media range
	wildcardSuffix   = []byte("/*")  // suffix of a subtype wildcard such as "text/*"
)
// headerParams maps header parameter names to their raw byte values.
type headerParams map[string][]byte
// ValueFromContext looks up key in ctx and returns it asserted to type T.
// The boolean is false when the key is absent, the stored value is not a T,
// or ctx is not one of the supported types:
//   - Ctx (including CustomCtx implementations): read via Locals
//   - *fasthttp.RequestCtx: read via UserValue
//   - context.Context: read via Value
//
// The case order matters: *fasthttp.RequestCtx also satisfies context.Context,
// so it must be matched before the generic context case.
func ValueFromContext[T any](ctx, key any) (T, bool) {
	var raw any
	switch src := ctx.(type) {
	case Ctx:
		raw = src.Locals(key)
	case *fasthttp.RequestCtx:
		raw = src.UserValue(key)
	case context.Context:
		raw = src.Value(key)
	default:
		var zero T
		return zero, false
	}

	v, ok := raw.(T)
	return v, ok
}
// StoreInContext saves value under key in the request's Locals. When the app
// is configured with PassLocalsToContext, the pair is also added to the
// request's context.Context, so both c.Locals() and context lookups can
// retrieve it.
func StoreInContext(c Ctx, key, value any) {
	c.Locals(key, value)
	if !c.App().config.PassLocalsToContext {
		return
	}
	c.SetContext(context.WithValue(c.Context(), key, value))
}
// getTLSConfig returns the *tls.Config used by a net.Listener, or nil when
// none can be determined.
//
// Lookup order:
//  1. a TLSConfig() *tls.Config method on the listener;
//  2. a Config() *tls.Config method on the listener;
//  3. a reflection fallback that reads a field named "config" of type
//     *tls.Config from the (dereferenced) listener struct, including an
//     unexported one (NOTE(review): this targets listeners such as the one
//     returned by crypto/tls.NewListener and relies on their internals).
//
// Listeners whose underlying type is not a struct yield nil instead of
// panicking inside reflect.
func getTLSConfig(ln net.Listener) *tls.Config {
	if ln == nil {
		return nil
	}

	type tlsConfigProvider interface {
		TLSConfig() *tls.Config
	}
	type configProvider interface {
		Config() *tls.Config
	}

	if provider, ok := ln.(tlsConfigProvider); ok {
		return provider.TLSConfig()
	}
	if provider, ok := ln.(configProvider); ok {
		return provider.Config()
	}

	pointer := reflect.ValueOf(ln)
	if !pointer.IsValid() {
		return nil
	}

	// Reflection fallback for listeners that do not expose a TLS config method.
	val := reflect.Indirect(pointer)
	// FieldByName panics on non-struct kinds (e.g. `type myListener int`), and
	// Indirect yields an invalid Value for a nil pointer.
	if !val.IsValid() || val.Kind() != reflect.Struct {
		return nil
	}

	field := val.FieldByName("config")
	if !field.IsValid() {
		return nil
	}
	if field.Type() != reflect.TypeFor[*tls.Config]() {
		return nil
	}

	// Exported (or otherwise interface-able) field: read it directly.
	if field.CanInterface() {
		if cfg, ok := field.Interface().(*tls.Config); ok {
			return cfg
		}
		return nil
	}

	// Unexported fields cannot be read via Interface(); when the field is
	// addressable, re-create a readable Value over the same memory.
	if !field.CanAddr() {
		return nil
	}
	value := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() //nolint:gosec // Access to unexported field is required for listeners that don't expose TLS config methods.
	if !value.IsValid() {
		return nil
	}

	cfg, ok := value.Interface().(*tls.Config)
	if !ok {
		return nil
	}
	return cfg
}
// readContent opens the named file (path cleaned with filepath.Clean) and
// streams its content into rf via ReadFrom, returning the number of bytes
// read. Open and read failures are wrapped; a failure to close the file is
// only logged and does not change the returned values.
func readContent(rf io.ReaderFrom, name string) (int64, error) {
	file, openErr := os.Open(filepath.Clean(name))
	if openErr != nil {
		return 0, fmt.Errorf("failed to open: %w", openErr)
	}
	defer func() {
		if closeErr := file.Close(); closeErr != nil {
			log.Errorf("Error closing file: %s", closeErr)
		}
	}()

	n, err := rf.ReadFrom(file)
	if err != nil {
		return n, fmt.Errorf("failed to read: %w", err)
	}
	return n, nil
}
// quoteString escapes special characters using percent-encoding.
// Non-ASCII bytes are encoded as well so the result is always ASCII.
//
// The encoded bytes are produced in a pooled scratch buffer and copied out
// with a plain string conversion before the buffer goes back to the pool.
// app.toString is not guaranteed to copy (NOTE(review): expected to be a
// zero-copy conversion unless Immutable is enabled — confirm in app.go). If
// it did not copy, the returned string would alias memory that the next pool
// user overwrites.
func (app *App) quoteString(raw string) string {
	bb := bytebufferpool.Get()
	defer bytebufferpool.Put(bb)

	// Keep the (possibly grown) slice in bb so the pool can reuse its capacity.
	bb.B = fasthttp.AppendQuotedArg(bb.B[:0], app.toBytes(raw))
	return string(bb.B)
}
// quoteRawString escapes only characters that need quoting according to
// https://www.rfc-editor.org/rfc/rfc9110#section-5.6.4 so the result may
// contain non-ASCII bytes.
//
// Escaping rules:
//   - backslash and double quote are prefixed with a backslash;
//   - '\n' and '\r' become the two-character sequences `\n` and `\r`;
//   - other control bytes (< 0x20) and DEL (0x7f) are percent-encoded as %XX;
//   - every other byte, including non-ASCII, is copied unchanged.
//
// The scratch buffer is returned to the pool when the function exits, so the
// result is an independent copy rather than a view over the buffer.
func (*App) quoteRawString(raw string) string {
	const hex = "0123456789ABCDEF"

	bb := bytebufferpool.Get()
	defer bytebufferpool.Put(bb)

	for i := 0; i < len(raw); i++ {
		c := raw[i]
		switch {
		case c == '\\' || c == '"':
			// escape backslash and quote
			bb.B = append(bb.B, '\\', c)
		case c == '\n':
			bb.B = append(bb.B, '\\', 'n')
		case c == '\r':
			bb.B = append(bb.B, '\\', 'r')
		case c < 0x20 || c == 0x7f:
			// percent-encode control and DEL
			bb.B = append(bb.B,
				'%',
				hex[c>>4],
				hex[c&0x0f],
			)
		default:
			bb.B = append(bb.B, c)
		}
	}

	// string(...) copies, so the result stays valid after bb is recycled;
	// a zero-copy conversion would alias pooled memory.
	return string(bb.B)
}
// isASCII reports whether every byte of s is in the 7-bit ASCII range
// (0x00-0x7F). The empty string counts as ASCII.
// See: https://www.rfc-editor.org/rfc/rfc0020
func (*App) isASCII(s string) bool {
	for i := range len(s) {
		if s[i] >= 0x80 {
			return false
		}
	}
	return true
}
// uniqueRouteStack returns a new slice with duplicate *Route pointers removed,
// keeping the first occurrence of each and preserving the original order.
// The input slice is not modified.
func uniqueRouteStack(stack []*Route) []*Route {
	seen := make(map[*Route]struct{}, len(stack))
	out := make([]*Route, 0, len(stack))
	for _, r := range stack {
		if _, dup := seen[r]; dup {
			continue
		}
		seen[r] = struct{}{}
		out = append(out, r)
	}
	return out
}
// defaultString returns value unless it is empty, in which case the first
// element of defaultValue is returned. With no defaults, value ("") is
// returned as-is.
func defaultString(value string, defaultValue []string) string {
	if value != "" || len(defaultValue) == 0 {
		return value
	}
	return defaultValue[0]
}
// getGroupPath joins a group prefix and a route path. An empty path yields
// the prefix unchanged; otherwise trailing '/' characters are trimmed from
// the prefix and the path is appended with exactly one leading '/'.
func getGroupPath(prefix, path string) string {
	switch {
	case path == "":
		return prefix
	case path[0] == '/':
		return strings.TrimRight(prefix, "/") + path
	default:
		return strings.TrimRight(prefix, "/") + "/" + path
	}
}
// acceptsOffer reports whether offer satisfies spec using case-insensitive
// comparison. A spec ending in '*' matches any offer that starts with the
// text before the '*'; any other spec must equal the offer.
func acceptsOffer(spec, offer string, _ headerParams) bool {
	prefix, wildcard := strings.CutSuffix(spec, "*")
	if !wildcard {
		return utils.EqualFold(spec, offer)
	}
	if len(offer) < len(prefix) {
		return false
	}
	return utils.EqualFold(prefix, offer[:len(prefix)])
}
// acceptsLanguageOfferBasic implements RFC 4647 Basic Filtering for a single
// language range and tag, case-insensitively:
//   - "*" alone matches every tag;
//   - any other range containing '*' never matches;
//   - otherwise the range must equal the tag, or be a prefix of it that is
//     immediately followed by '-' in the tag.
func acceptsLanguageOfferBasic(spec, offer string, _ headerParams) bool {
	switch {
	case spec == "*":
		return true
	case strings.Contains(spec, "*"):
		return false
	case utils.EqualFold(spec, offer):
		return true
	}

	n := len(spec)
	return len(offer) > n && offer[n] == '-' && utils.EqualFold(offer[:n], spec)
}
// acceptsLanguageOfferExtended determines if a language tag offer matches a
// range according to RFC 4647 Extended Filtering (§3.3.2).
//   - Case-insensitive comparisons
//   - '*' matches zero or more subtags (can "slide")
//   - Unspecified subtags are treated like '*' (so trailing/extraneous tag subtags are fine)
//   - Matching fails if sliding encounters a singleton (incl. 'x')
//
// A bare "*" range matches every offer. An empty range or empty offer never
// matches. The step labels in the comments below follow the numbered
// algorithm in RFC 4647 §3.3.2.
func acceptsLanguageOfferExtended(spec, offer string, _ headerParams) bool {
	if spec == "*" {
		return true
	}
	if spec == "" || offer == "" {
		return false
	}
	// Use stack-allocated arrays to avoid heap allocations for typical language tags
	// (up to 8 subtags each; append moves to a new backing array beyond that).
	var rsBuf, tsBuf [8]string
	rs := rsBuf[:0]
	ts := tsBuf[:0]
	// Parse spec subtags without allocation for typical cases
	for s := range strings.SplitSeq(spec, "-") {
		rs = append(rs, s)
	}
	// Parse offer subtags without allocation for typical cases
	for s := range strings.SplitSeq(offer, "-") {
		ts = append(ts, s)
	}
	// Both inputs are non-empty here, so rs and ts hold at least one element.
	// Step 2: first subtag must match (or be '*')
	if rs[0] != "*" && !utils.EqualFold(rs[0], ts[0]) {
		return false
	}
	i, j := 1, 1 // i = range index, j = tag index
	for i < len(rs) {
		if rs[i] == "*" { // 3.A: '*' matches zero or more subtags
			i++
			continue
		}
		if j >= len(ts) { // 3.B: ran out of tag subtags
			return false
		}
		if utils.EqualFold(rs[i], ts[j]) { // 3.C: exact subtag match
			i++
			j++
			continue
		}
		// 3.D: singleton barrier (one letter or digit, incl. 'x')
		// A one-character subtag starts an extension/private-use sequence that
		// a range subtag is not allowed to slide past.
		if len(ts[j]) == 1 {
			return false
		}
		// 3.E: slide forward in the tag and try again
		j++
	}
	// 4: matched all range subtags
	// Any leftover tag subtags are accepted (treated as if matched by '*').
	return true
}
// acceptsOfferType reports whether an offer satisfies a media range taken from
// an Accept header.
//
//   - A spec of "*/*" accepts any offer.
//   - An offer without a '/' is treated as a file extension and resolved to a
//     MIME type with utils.GetMIME.
//   - A "type/subtype" spec must equal the offer's MIME type. A "type/*" spec
//     (or an offer of the form "type/*") matches on the type alone. Type and
//     subtype comparisons are case-insensitive, as required by RFC 9110 §8.3.1.
//   - Every parameter present in specParams must also be present in the offer
//     with an equal value (see paramsMatch).
//
// Returns true if the offer matches the specification, false otherwise.
func acceptsOfferType(spec, offerType string, specParams headerParams) bool {
	offerMime, offerParams := offerType, ""
	if i := strings.IndexByte(offerType, ';'); i != -1 {
		offerMime, offerParams = offerType[:i], offerType[i:]
	}

	// Accept: */*
	if spec == "*/*" {
		return paramsMatch(specParams, offerParams)
	}

	mimetype := offerMime
	if strings.IndexByte(offerMime, '/') == -1 {
		mimetype = utils.GetMIME(offerMime) // offer is a file extension
	}

	// Accept: <MIME_type>/<MIME_subtype>
	// Media types are case-insensitive, so "text/HTML" must match "text/html".
	if utils.EqualFold(spec, mimetype) {
		return paramsMatch(specParams, offerParams)
	}

	// Accept: <MIME_type>/*
	s := strings.IndexByte(mimetype, '/')
	specSlash := strings.IndexByte(spec, '/')
	if s != -1 && specSlash != -1 &&
		utils.EqualFold(spec[:specSlash], mimetype[:s]) &&
		(spec[specSlash:] == "/*" || mimetype[s:] == "/*") {
		return paramsMatch(specParams, offerParams)
	}
	return false
}
// paramsMatch reports whether offerParams (the raw ";key=value..." suffix of an
// offer) contains every parameter in specParamStr with an equal value.
//
// Parameter names and values are compared case-insensitively, and parameter
// order is ignored. Quoted values are compared without their surrounding quotes.
// Each offer value is passed through unescapeHeaderValue, so a quoted-pair
// such as \" compares equal to the character it escapes (RFC 9110 §5.6.4).
// An offer value with an invalid escape never matches. If the offer repeats a
// parameter, only its first occurrence is considered.
//
// To align with Express's res.format, getOffer keeps the last value when the
// incoming Accept header repeats a parameter. It stores spec values already
// unescaped.
//
// An empty specParamStr matches any offer.
// See https://www.rfc-editor.org/rfc/rfc9110#name-parameters
func paramsMatch(specParamStr headerParams, offerParams string) bool {
	if len(specParamStr) == 0 {
		return true
	}
	for specParam, specVal := range specParamStr {
		matched := false
		fasthttp.VisitHeaderParams(utils.UnsafeBytes(offerParams), func(key, value []byte) bool {
			if !utils.EqualFold(specParam, utils.UnsafeString(key)) {
				return true // keep looking for this parameter
			}
			// First occurrence decides; an invalid escape leaves matched false.
			if unescaped, err := unescapeHeaderValue(value); err == nil {
				matched = utils.EqualFold(specVal, unescaped)
			}
			return false
		})
		// Missing parameter, unequal value or invalid escape.
		if !matched {
			return false
		}
	}
	return true
}
// getSplicedStrList splits headerValue on ',' into dst and returns the result.
// Each element is trimmed with utils.TrimSpace, and empty elements are kept
// ("a,,b" yields ["a" "" "b"]).
//
// dst is truncated to length zero and reused, so its backing array may be
// overwritten. append allocates a larger one when dst lacks capacity. An empty
// headerValue returns nil without touching dst.
func getSplicedStrList(headerValue string, dst []string) []string {
	if headerValue == "" {
		return nil
	}
	dst = dst[:0]
	rest := headerValue
	for {
		element, tail, found := strings.Cut(rest, ",")
		dst = append(dst, utils.TrimSpace(element))
		if !found {
			return dst
		}
		rest = tail
	}
}
// joinHeaderValues merges repeated header values into one comma-separated
// value. It returns nil for no values, and returns the single value itself
// (not a copy) when there is exactly one. Otherwise it returns a newly
// allocated slice joined with ','.
func joinHeaderValues(headers [][]byte) []byte {
	if len(headers) == 0 {
		return nil
	}
	if len(headers) == 1 {
		return headers[0]
	}
	return bytes.Join(headers, []byte{','})
}
// unescapeHeaderValue resolves quoted-pairs in a header parameter value. Each
// backslash is dropped and the byte after it is kept literally, so `a\"b`
// becomes `a"b` and `a\\b` becomes `a\b`.
//
// If v contains no backslash it is returned unchanged and the result aliases
// v. Otherwise a new slice is allocated. A trailing, unpaired backslash yields
// errInvalidEscapeSequence.
func unescapeHeaderValue(v []byte) ([]byte, error) {
	if bytes.IndexByte(v, '\\') == -1 {
		return v, nil
	}
	out := make([]byte, 0, len(v))
	for i := 0; i < len(v); i++ {
		c := v[i]
		if c == '\\' {
			// A backslash must be followed by the byte it escapes.
			i++
			if i == len(v) {
				return nil, errInvalidEscapeSequence
			}
			c = v[i]
		}
		out = append(out, c)
	}
	return out, nil
}
// forEachMediaRange splits an Accept or Content-Type style header on top-level
// commas and calls functor once for each non-empty media range.
//
// Leading spaces are stripped from every range; trailing whitespace is left to
// the caller. Commas inside quoted-strings do not split ranges. Inside a
// quoted-string a backslash starts a quoted-pair, so the next byte (including
// '"') is taken literally. Empty list elements such as in "a,,b" or a trailing
// ", " are skipped, as RFC 9110 §5.6.1.2 requires recipients to ignore them.
//
// The slices passed to functor alias header.
// See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields
func forEachMediaRange(header []byte, functor func([]byte)) {
	hasDQuote := bytes.IndexByte(header, '"') != -1
	for len(header) > 0 {
		header = bytes.TrimLeft(header, " ")
		n := len(header) // end of the current range (exclusive)
		if hasDQuote {
			// Complex case: track quoted-strings and quoted-pairs.
			inQuotes, escaping := false, false
		scan:
			for i, c := range header {
				switch {
				case escaping:
					// Byte after a backslash inside quotes: literal, then done.
					escaping = false
				case c == '\\':
					// Only meaningful inside a quoted-string.
					escaping = inQuotes
				case c == '"':
					inQuotes = !inQuotes
				case c == ',' && !inQuotes:
					n = i
					break scan
				}
			}
		} else if i := bytes.IndexByte(header, ','); i != -1 {
			// Simple case: no quotes, so the next comma ends the range.
			n = i
		}

		if n > 0 {
			functor(header[:n])
		}
		if n >= len(header) {
			return
		}
		header = header[n+1:]
	}
}
// headerParamPool recycles the headerParams maps that getOffer uses while
// collecting media-range parameters. A map taken from the pool may still hold
// entries from an earlier use, so it must be cleared after Get before it is
// populated.
var headerParamPool = sync.Pool{
	New: func() any {
		return make(headerParams)
	},
}
// getOffer picks the best offer for a content-negotiation header such as
// Accept, Accept-Charset, Accept-Encoding or Accept-Language.
//
// header is the raw header value. isAccepted reports whether one offer
// satisfies one parsed range (its spec plus parameters). offers are tried in
// the given order, and empty offers are skipped.
//
// It returns "" when offers is empty, and offers[0] when header is empty.
// Otherwise ranges with q=0 are dropped and the rest are ranked by
// sortAcceptedTypes (quality, specificity, parameter count, header order).
// The first offer accepted by the highest-ranked range wins, and "" is
// returned when nothing matches.
func getOffer(header []byte, isAccepted func(spec, offer string, specParams headerParams) bool, offers ...string) string {
	if len(offers) == 0 {
		return ""
	}
	if len(header) == 0 {
		return offers[0]
	}

	// releaseParams clears a pooled params map and hands it back to
	// headerParamPool. nil maps (ranges parsed via the q-only fast path or
	// without parameters) are ignored.
	releaseParams := func(p headerParams) {
		if p != nil {
			clear(p)
			headerParamPool.Put(p)
		}
	}

	acceptedTypes := make([]acceptedType, 0, 8)
	// Return every pooled map on all exit paths, not only the map of the
	// range that produced a match.
	defer func() {
		for i := range acceptedTypes {
			releaseParams(acceptedTypes[i].params)
		}
	}()
	order := 0

	// Parse header and get accepted types with their quality and specificity
	// See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields
	forEachMediaRange(header, func(accept []byte) {
		order++
		spec, quality := accept, 1.0
		var params headerParams
		if i := bytes.IndexByte(accept, ';'); i != -1 {
			spec = accept[:i]

			// Fast path: "q" is the only parameter (e.g. "text/html;q=0.8").
			qIndex := i + 3
			if bytes.HasPrefix(accept[i:], semicolonQEquals) && bytes.IndexByte(accept[qIndex:], ';') == -1 {
				// Trim OWS left before the list separator. Otherwise
				// "text/html;q=0 , */*" fails to parse and is treated as q=1.
				if q, err := fasthttp.ParseUfloat(utils.TrimSpace(accept[qIndex:])); err == nil {
					quality = q
				}
			} else {
				params, _ = headerParamPool.Get().(headerParams) //nolint:errcheck // only contains headerParams
				// Pooled maps may still hold entries from an earlier use.
				clear(params)
				fasthttp.VisitHeaderParams(accept[i:], func(key, value []byte) bool {
					// "q" is the weight; parameters after it are not collected.
					if len(key) == 1 && key[0] == 'q' {
						if q, err := fasthttp.ParseUfloat(value); err == nil {
							quality = q
						}
						return false
					}
					lowerKey := utils.UnsafeString(utilsbytes.UnsafeToLower(key))
					val, err := unescapeHeaderValue(value)
					if err != nil {
						return true // skip parameters with an invalid escape
					}
					params[lowerKey] = val
					return true
				})
			}

			// Skip this accept type if quality is 0.0
			// See: https://www.rfc-editor.org/rfc/rfc9110#quality.values
			if quality == 0.0 {
				releaseParams(params)
				return
			}
		}
		spec = utils.TrimSpace(spec)

		// Determine specificity. A bare "*" or wildcardAll ranks lowest, then
		// ranges ending in wildcardSuffix, then full MIME types, then
		// non-MIME tokens (charsets, encodings, languages).
		var specificity int
		switch {
		case len(spec) == 1 && spec[0] == '*':
			specificity = 1
		case bytes.Equal(spec, wildcardAll):
			specificity = 1
		case bytes.HasSuffix(spec, wildcardSuffix):
			specificity = 2
		case bytes.IndexByte(spec, '/') != -1:
			specificity = 3
		default:
			specificity = 4
		}

		// Add to accepted types
		acceptedTypes = append(acceptedTypes, acceptedType{
			spec:        utils.UnsafeString(spec),
			quality:     quality,
			specificity: specificity,
			order:       order,
			params:      params,
		})
	})

	if len(acceptedTypes) > 1 {
		// Sort accepted types by quality and specificity, preserving order of equal elements
		sortAcceptedTypes(acceptedTypes)
	}

	// Find the first offer that matches the highest-ranked accepted type.
	for i := range acceptedTypes {
		at := &acceptedTypes[i]
		for _, offer := range offers {
			if offer != "" && isAccepted(at.spec, offer, at.params) {
				return offer
			}
		}
	}
	return ""
}
// sortAcceptedTypes orders at in place from most to least preferred:
//  1. higher quality first;
//  2. then higher specificity;
//  3. then more parameters, so "text/html;a=1;b=2" comes before
//     "text/html;a=1";
//  4. then lower order, i.e. earlier position in the header.
//
// See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields
//
// The implementation is a binary insertion sort over the slice.
func sortAcceptedTypes(at []acceptedType) {
	// ranksBelow reports whether x belongs strictly after y.
	ranksBelow := func(x, y *acceptedType) bool {
		if x.quality != y.quality {
			return x.quality < y.quality
		}
		if x.specificity != y.specificity {
			return x.specificity < y.specificity
		}
		if len(x.params) != len(y.params) {
			return len(x.params) < len(y.params)
		}
		return x.order > y.order
	}

	for i := 1; i < len(at); i++ {
		// Binary-search the sorted prefix at[:i] for the slot of at[i].
		lo, hi := 0, i-1
		for lo <= hi {
			mid := (lo + hi) / 2
			if ranksBelow(&at[i], &at[mid]) {
				lo = mid + 1
			} else {
				hi = mid - 1
			}
		}
		// Bubble at[i] down to index lo, shifting the others up by one.
		for j := i; j > lo; j-- {
			at[j-1], at[j] = at[j], at[j-1]
		}
	}
}
// normalizeEtag parses an entity tag of the form "value" or W/"value".
// It returns the value with its quotes removed, whether the tag carried the
// weak "W/" prefix, and whether the tag was properly quoted. For a malformed
// tag, value is "" and ok is false, while weak still reflects the prefix.
func normalizeEtag(t string) (value string, weak, ok bool) { //nolint:nonamedreturns // gocritic unnamedResult requires naming the parsed ETag components
	quoted, isWeak := strings.CutPrefix(t, "W/")
	if len(quoted) < 2 || quoted[0] != '"' || quoted[len(quoted)-1] != '"' {
		return "", isWeak, false
	}
	return quoted[1 : len(quoted)-1], isWeak, true
}
// matchEtag compares two entity tags using the weak comparison of
// RFC 9110 §8.8.3.2. A "W/" prefix on either side is ignored, so W/"a" equals
// "a". Both tags must be properly quoted; a malformed tag never matches.
func matchEtag(s, etag string) bool {
	left, _, leftOK := normalizeEtag(s)
	if !leftOK {
		return false
	}
	right, _, rightOK := normalizeEtag(etag)
	return rightOK && left == right
}
// matchEtagStrong compares two entity tags using the strong comparison of
// RFC 9110 §8.8.3.1. Both tags must be properly quoted and neither may be
// weak, so W/"a" never matches "a" even though the quoted values are equal.
func matchEtagStrong(s, etag string) bool {
	left, leftWeak, leftOK := normalizeEtag(s)
	if !leftOK || leftWeak {
		return false
	}
	right, rightWeak, rightOK := normalizeEtag(etag)
	return rightOK && !rightWeak && left == right
}
// isEtagStale reports whether a response with the given ETag is stale with
// respect to the raw If-None-Match header value noneMatchBytes. It returns
// false as soon as any listed entity-tag matches etag under weak comparison
// (RFC 9110 §8.8.3.2), and true when none does.
//
// A header consisting only of "*" is never treated as stale. Optional
// whitespace (SP or HTAB, RFC 9110 §5.6.3) around list elements is ignored.
func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool {
	var start, end int
	header := utils.TrimSpace(app.toString(noneMatchBytes))

	// Short-circuit the wildcard case: "*" never counts as stale.
	if header == "*" {
		return false
	}

	// Adapted from:
	// https://github.com/jshttp/fresh/blob/master/index.js#L110
	// [start:end] spans the current element. Whitespace before an element
	// moves start forward; whitespace after it is excluded because end only
	// advances on non-whitespace bytes.
	for i, c := range noneMatchBytes {
		switch c {
		case ' ', '\t':
			if start == end {
				start = i + 1
				end = i + 1
			}
		case ',':
			if matchEtag(app.toString(noneMatchBytes[start:end]), etag) {
				return false
			}
			start = i + 1
			end = i + 1
		default:
			end = i + 1
		}
	}

	return !matchEtag(app.toString(noneMatchBytes[start:end]), etag)
}
// parseAddr splits a "host[:port]" string (after trimming surrounding
// whitespace with utils.TrimSpace) into host and port.
//
//   - A bracketed IPv6 literal (RFC 3986) keeps its brackets: "[::1]:8080"
//     yields ("[::1]", "8080"), and "[::1]" yields ("[::1]", "").
//   - Otherwise the last ':' separates host and port. If the host part would
//     still contain a ':', the input is an unbracketed IPv6 literal that
//     cannot carry a port, so the whole string is returned as the host.
//   - Input without ':' is returned as the host with an empty port.
func parseAddr(raw string) (host, port string) { //nolint:nonamedreturns // gocritic unnamedResult requires naming host and port parts for clarity
	if raw == "" {
		return "", ""
	}
	raw = utils.TrimSpace(raw)

	if raw != "" && raw[0] == '[' {
		if closing := strings.IndexByte(raw, ']'); closing != -1 {
			host = raw[:closing+1] // keep the closing ]
			if rest := raw[closing+1:]; rest != "" && rest[0] == ':' {
				return host, rest[1:]
			}
			return host, ""
		}
	}

	sep := strings.LastIndexByte(raw, ':')
	switch {
	case sep == -1:
		// No colon, nothing to split.
		return raw, ""
	case strings.Contains(raw[:sep], ":"):
		// Un-bracketed IPv6 literal: a port is impossible.
		return raw, ""
	default:
		return raw[:sep], raw[sep+1:]
	}
}
// isNoCache reports whether a Cache-Control header value contains a
// `no-cache` directive, matched ASCII case-insensitively as a whole token.
// Per RFC 9111 §5.2.2.4 the directive may appear bare ("no-cache") or with a
// field-name argument ("no-cache=field-name"). Both forms mean the response
// must not be served from cache without revalidation.
//
// The token must start the value or follow a space, tab or comma. It must end
// the value or be followed by one of those delimiters or by '='.
func isNoCache(cacheControl string) bool {
	const noCacheLen = len(noCacheValue)
	const asciiCaseFold = byte(0x20)

	// Values shorter than the token give a negative bound and skip the loop.
	last := len(cacheControl) - noCacheLen
	for i := 0; i <= last; i++ {
		if (cacheControl[i]|asciiCaseFold) != 'n' || !matchNoCacheToken(cacheControl, i) {
			continue
		}
		if i > 0 && !isNoCacheDelimiter(cacheControl[i-1]) {
			continue // part of a longer token, e.g. "x-no-cache"
		}
		end := i + noCacheLen
		if end == len(cacheControl) {
			return true
		}
		if next := cacheControl[end]; next == '=' || isNoCacheDelimiter(next) {
			return true
		}
	}
	return false
}
// isNoCacheDelimiter reports whether c can separate a Cache-Control directive
// from its neighbors: a space, a horizontal tab or a comma.
func isNoCacheDelimiter(c byte) bool {
	switch c {
	case ' ', '\t', ',':
		return true
	default:
		return false
	}
}
// matchNoCacheToken reports whether s contains "no-cache" starting at byte
// offset i, compared ASCII case-insensitively. The hyphen must match exactly.
// It only checks the eight token bytes; delimiter checks are left to the
// caller. It returns false, rather than panicking, when i is out of range or
// fewer than eight bytes remain.
func matchNoCacheToken(s string, i int) bool {
	const asciiCaseFold = byte(0x20)
	const tokenLen = len("no-cache")
	if i < 0 || len(s)-i < tokenLen {
		return false
	}
	// Slicing to exactly tokenLen also lets the compiler drop per-index
	// bounds checks below.
	b := s[i : i+tokenLen]
	// OR-ing 0x20 folds ASCII upper case to lower case. The '-' is compared
	// without folding because 0x0d|0x20 would also equal '-'.
	return (b[0]|asciiCaseFold) == 'n' &&
		(b[1]|asciiCaseFold) == 'o' &&
		b[2] == '-' &&
		(b[3]|asciiCaseFold) == 'c' &&
		(b[4]|asciiCaseFold) == 'a' &&
		(b[5]|asciiCaseFold) == 'c' &&
		(b[6]|asciiCaseFold) == 'h' &&
		(b[7]|asciiCaseFold) == 'e'
}
var errTestConnClosed = errors.New("testConn is closed")

// testConn is an in-memory net.Conn stand-in. Reads are served from r, writes
// are appended to w, and Read, Write and Close hold the embedded mutex while
// they touch the buffers or the closed flag.
type testConn struct {
	r        bytes.Buffer
	w        bytes.Buffer
	isClosed bool
	sync.Mutex
}

// Read implements net.Conn by draining buffered input from r. It does not
// consult isClosed, so data queued before Close can still be read.
func (c *testConn) Read(b []byte) (int, error) {
	c.Lock()
	defer c.Unlock()

	return c.r.Read(b) //nolint:wrapcheck // This must not be wrapped
}

// Write implements net.Conn by appending b to w. After Close it fails with
// errTestConnClosed.
func (c *testConn) Write(b []byte) (int, error) {
	c.Lock()
	defer c.Unlock()

	if !c.isClosed {
		return c.w.Write(b) //nolint:wrapcheck // This must not be wrapped
	}
	return 0, errTestConnClosed
}

// Close implements net.Conn. It only marks the connection as closed and
// always returns nil.
func (c *testConn) Close() error {
	c.Lock()
	defer c.Unlock()

	c.isClosed = true
	return nil
}

// LocalAddr implements net.Conn and returns the placeholder 0.0.0.0:0.
func (*testConn) LocalAddr() net.Addr {
	return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero}
}

// RemoteAddr implements net.Conn and returns the placeholder 0.0.0.0:0.
func (*testConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{Port: 0, Zone: "", IP: net.IPv4zero}
}

// SetDeadline implements net.Conn; deadlines are ignored in memory.
func (*testConn) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline implements net.Conn; deadlines are ignored in memory.
func (*testConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline implements net.Conn; deadlines are ignored in memory.
func (*testConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
// toStringImmutable converts b to a string by copying its bytes, so the result
// is unaffected by later changes to b.
func toStringImmutable(b []byte) string {
	copied := string(b)
	return copied
}

// toBytesImmutable converts s to a freshly allocated byte slice, so mutating
// the result never affects any other value.
func toBytesImmutable(s string) []byte {
	copied := []byte(s)
	return copied
}
// methodInt maps an HTTP method name to its integer index.
//
// When app.configured.RequestMethods is empty, a switch over the built-in
// methods is used for speed. Otherwise the index of s in
// app.config.RequestMethods is returned. Unknown methods yield -1 in both
// cases.
func (app *App) methodInt(s string) int {
	// Customized method lists are resolved by position.
	if len(app.configured.RequestMethods) != 0 {
		return slices.Index(app.config.RequestMethods, s)
	}

	// Default method set: a switch avoids scanning the slice.
	switch s {
	case MethodGet:
		return methodGet
	case MethodHead:
		return methodHead
	case MethodPost:
		return methodPost
	case MethodPut:
		return methodPut
	case MethodDelete:
		return methodDelete
	case MethodConnect:
		return methodConnect
	case MethodOptions:
		return methodOptions
	case MethodTrace:
		return methodTrace
	case MethodPatch:
		return methodPatch
	}
	return -1
}
// method returns the HTTP method name stored at index methodInt in
// app.config.RequestMethods. An out-of-range index (including -1) panics, so
// callers presumably pass only valid indices obtained from methodInt.
func (app *App) method(methodInt int) string {
	methods := app.config.RequestMethods
	return methods[methodInt]
}
// IsMethodSafe reports whether m is a safe HTTP method, i.e. one of GET, HEAD,
// OPTIONS or TRACE. The comparison is exact and case-sensitive.
// See https://datatracker.ietf.org/doc/html/rfc9110#section-9.2.1
func IsMethodSafe(m string) bool {
	return m == MethodGet ||
		m == MethodHead ||
		m == MethodOptions ||
		m == MethodTrace
}
// IsMethodIdempotent reports whether m is an idempotent HTTP method: any safe
// method (see IsMethodSafe) plus PUT and DELETE.
// See https://datatracker.ietf.org/doc/html/rfc9110#section-9.2.2
func IsMethodIdempotent(m string) bool {
	return IsMethodSafe(m) || m == MethodPut || m == MethodDelete
}
// Convert parses value with converter and returns the result.
//
// On success the converted value is returned with a nil error. If converter
// fails and at least one defaultValue is supplied, the first default is
// returned with a nil error. Otherwise the converter's own (possibly non-zero)
// result is returned together with its error, wrapped as
// "failed to convert: <err>".
func Convert[T any](value string, converter func(string) (T, error), defaultValue ...T) (T, error) {
	converted, err := converter(value)
	switch {
	case err == nil:
		return converted, nil
	case len(defaultValue) > 0:
		return defaultValue[0], nil
	default:
		return converted, fmt.Errorf("failed to convert: %w", err)
	}
}
var (
	errParsedEmptyString = errors.New("parsed result is empty string")
	errParsedEmptyBytes  = errors.New("parsed result is empty bytes")
	errParsedType        = errors.New("unsupported generic type")
)

// genericParseType parses str into a value of type V.
//
// Numeric and bool types are parsed with the matching utils/strconv helper,
// and failures are wrapped as "failed to parse <type>: <err>". For int and
// uint the parsed 64-bit value is also checked against the platform word size.
// On 32-bit targets an out-of-range input therefore returns an error wrapping
// strconv.ErrRange instead of silently wrapping around.
//
// string and []byte accept any non-empty input and return
// errParsedEmptyString / errParsedEmptyBytes for "". Any other type returns
// errParsedType; the GenericType constraint covers only the handled cases.
func genericParseType[V GenericType](str string) (V, error) {
	var v V
	switch any(v).(type) {
	case int:
		result, err := utils.ParseInt(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse int: %w", err)
		}
		// Round-tripping detects truncation when int is narrower than int64.
		if int64(int(result)) != result {
			return v, fmt.Errorf("failed to parse int: %w", strconv.ErrRange)
		}
		return any(int(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case int8:
		result, err := utils.ParseInt8(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse int8: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case int16:
		result, err := utils.ParseInt16(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse int16: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case int32:
		result, err := utils.ParseInt32(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse int32: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case int64:
		result, err := utils.ParseInt(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse int64: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case uint:
		result, err := utils.ParseUint(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse uint: %w", err)
		}
		// Round-tripping detects truncation when uint is narrower than uint64.
		if uint64(uint(result)) != result {
			return v, fmt.Errorf("failed to parse uint: %w", strconv.ErrRange)
		}
		return any(uint(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case uint8:
		result, err := utils.ParseUint8(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse uint8: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case uint16:
		result, err := utils.ParseUint16(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse uint16: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case uint32:
		result, err := utils.ParseUint32(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse uint32: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case uint64:
		result, err := utils.ParseUint(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse uint64: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case float32:
		result, err := utils.ParseFloat32(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse float32: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case float64:
		result, err := utils.ParseFloat64(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse float64: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case bool:
		result, err := strconv.ParseBool(str)
		if err != nil {
			return v, fmt.Errorf("failed to parse bool: %w", err)
		}
		return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case string:
		if str == "" {
			return v, errParsedEmptyString
		}
		return any(str).(V), nil //nolint:errcheck,forcetypeassert // not needed
	case []byte:
		if str == "" {
			return v, errParsedEmptyBytes
		}
		return any([]byte(str)).(V), nil //nolint:errcheck,forcetypeassert // not needed
	default:
		return v, errParsedType
	}
}
// GenericType is the set of types the generic string-parsing helpers can
// produce: bool, string, []byte, and every type in GenericTypeFloat and
// GenericTypeInteger.
type GenericType interface {
	bool | string | []byte | GenericTypeFloat | GenericTypeInteger
}

// GenericTypeInteger is the union of every supported signed and unsigned
// integer type.
type GenericTypeInteger interface {
	GenericTypeIntegerUnsigned | GenericTypeIntegerSigned
}

// GenericTypeIntegerSigned lists the supported signed integer types.
type GenericTypeIntegerSigned interface {
	int64 | int32 | int16 | int8 | int
}

// GenericTypeIntegerUnsigned lists the supported unsigned integer types.
type GenericTypeIntegerUnsigned interface {
	uint64 | uint32 | uint16 | uint8 | uint
}

// GenericTypeFloat lists the supported floating-point types.
type GenericTypeFloat interface {
	float64 | float32
}
================================================
FILE: helpers_fuzz_test.go
================================================
//go:build go1.18
package fiber
import (
"testing"
)
// go test -v -run=^$ -fuzz=FuzzUtilsGetOffer
//
// FuzzUtilsGetOffer feeds arbitrary Accept header values to getOffer with
// acceptsOfferType and a fixed pair of parameterized offers, checking only
// that parsing and matching never panic.
func FuzzUtilsGetOffer(f *testing.F) {
	seeds := []string{
		`application/json; v=1; foo=bar; q=0.938; extra=param, text/plain;param="big fox"; q=0.43`,
		`text/html, application/xhtml+xml, application/xml;q=0.9, */*;q=0.8`,
		`*/*`,
		`text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c`,
	}
	for _, seed := range seeds {
		f.Add(seed)
	}

	offers := []string{`application/json;version=1;v=1;foo=bar`, `text/plain;param="big fox"`}
	f.Fuzz(func(_ *testing.T, header string) {
		getOffer([]byte(header), acceptsOfferType, offers...)
	})
}
================================================
FILE: helpers_test.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📝 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"bytes"
"context"
"crypto/tls"
"math"
"net"
"os"
"strconv"
"testing"
"time"
"unsafe"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// Test_Utils_GetOffer covers header negotiation for media types (with
// parameters and priorities), charset/encoding tokens with wildcards, and
// language tags under both RFC 4647 basic and extended filtering.
func Test_Utils_GetOffer(t *testing.T) {
	t.Parallel()
	require.Empty(t, getOffer([]byte("hello"), acceptsOffer))
	require.Equal(t, "1", getOffer([]byte(""), acceptsOffer, "1"))
	require.Empty(t, getOffer([]byte("2"), acceptsOffer, "1"))
	require.Empty(t, getOffer([]byte(""), acceptsOfferType))
	require.Empty(t, getOffer([]byte("text/html"), acceptsOfferType))
	require.Empty(t, getOffer([]byte("text/html"), acceptsOfferType, "application/json"))
	require.Empty(t, getOffer([]byte("text/html;q=0"), acceptsOfferType, "text/html"))
	require.Empty(t, getOffer([]byte("application/json, */*; q=0"), acceptsOfferType, "image/png"))
	require.Equal(t, "application/xml", getOffer([]byte("text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"), acceptsOfferType, "application/xml", "application/json"))
	require.Equal(t, "text/html", getOffer([]byte("text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"), acceptsOfferType, "text/html"))
	require.Equal(t, "application/pdf", getOffer([]byte("text/plain;q=0,application/pdf;q=0.9,*/*;q=0.000"), acceptsOfferType, "application/pdf", "application/json"))
	// Same header with the offers reversed: the q=0 wildcard must not let
	// application/json through.
	require.Equal(t, "application/pdf", getOffer([]byte("text/plain;q=0,application/pdf;q=0.9,*/*;q=0.000"), acceptsOfferType, "application/json", "application/pdf"))
	require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain;a=1"), acceptsOfferType, "text/plain;a=1"))
	require.Empty(t, getOffer([]byte("text/plain;a=1;b=2"), acceptsOfferType, "text/plain;b=2"))

	// Spaces, quotes, out of order params, and case insensitivity
	require.Equal(t, "text/plain", getOffer([]byte("text/plain "), acceptsOfferType, "text/plain"))
	require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.4 "), acceptsOfferType, "text/plain"))
	require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.4 ;"), acceptsOfferType, "text/plain"))
	require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.4 ; p=foo"), acceptsOfferType, "text/plain"))
	require.Equal(t, "text/plain;b=2;a=1", getOffer([]byte("text/plain ;a=1;b=2"), acceptsOfferType, "text/plain;b=2;a=1"))
	require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain; a=1 "), acceptsOfferType, "text/plain;a=1"))
	require.Equal(t, `text/plain;a="1;b=2\",text/plain"`, getOffer([]byte(`text/plain;a="1;b=2\",text/plain";q=0.9`), acceptsOfferType, `text/plain;a=1;b=2`, `text/plain;a="1;b=2\",text/plain"`))
	require.Equal(t, "text/plain;A=CAPS", getOffer([]byte(`text/plain;a="caPs"`), acceptsOfferType, "text/plain;A=CAPS"))

	// Priority
	require.Equal(t, "text/plain", getOffer([]byte("text/plain"), acceptsOfferType, "text/plain", "text/plain;a=1"))
	require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain"), acceptsOfferType, "text/plain;a=1", "", "text/plain"))
	require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain,text/plain;a=1"), acceptsOfferType, "text/plain", "text/plain;a=1"))
	require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.899,text/plain;a=1;q=0.898"), acceptsOfferType, "text/plain", "text/plain;a=1"))
	require.Equal(t, "text/plain;a=1;b=2", getOffer([]byte("text/plain,text/plain;a=1,text/plain;a=1;b=2"), acceptsOfferType, "text/plain", "text/plain;a=1", "text/plain;a=1;b=2"))

	// Takes the last value specified
	require.Equal(t, "text/plain;a=1;b=2", getOffer([]byte("text/plain;a=1;b=1;B=2"), acceptsOfferType, "text/plain;a=1;b=1", "text/plain;a=1;b=2"))

	require.Empty(t, getOffer([]byte("utf-8, iso-8859-1;q=0.5"), acceptsOffer))
	require.Empty(t, getOffer([]byte("utf-8, iso-8859-1;q=0.5"), acceptsOffer, "ascii"))
	require.Equal(t, "utf-8", getOffer([]byte("utf-8, iso-8859-1;q=0.5"), acceptsOffer, "utf-8"))
	require.Equal(t, "iso-8859-1", getOffer([]byte("utf-8;q=0, iso-8859-1;q=0.5"), acceptsOffer, "utf-8", "iso-8859-1"))

	// Accept-Charset wildcard coverage
	require.Equal(t, "utf-8", getOffer([]byte("utf-*"), acceptsOffer, "utf-8"))
	require.Equal(t, "UTF-16", getOffer([]byte("utf-*"), acceptsOffer, "UTF-16", "iso-8859-1"))
	require.Empty(t, getOffer([]byte("utf-*"), acceptsOffer, "iso-8859-1"))
	require.Empty(t, getOffer([]byte("utf-*"), acceptsOffer, "utf"))
	require.Empty(t, getOffer([]byte("utf-*"), acceptsOffer, "x-utf-8"))

	// Complex wildcard negotiation
	require.Equal(t, "utf-16le", getOffer([]byte("utf-8;q=0.4, utf-*;q=0.8, iso-8859-1;q=0.6"), acceptsOffer, "iso-8859-1", "utf-16le"))
	require.Equal(t, "iso-8859-1", getOffer([]byte("utf-*;q=0.9, iso-8859-1;q=1"), acceptsOffer, "x-utf-16", "iso-8859-1"))
	require.Empty(t, getOffer([]byte("utf-*;q=0.5, iso-8859-1;q=0.4"), acceptsOffer, "ascii", "us-ascii"))

	require.Equal(t, "deflate", getOffer([]byte("gzip, deflate"), acceptsOffer, "deflate"))
	require.Empty(t, getOffer([]byte("gzip, deflate;q=0"), acceptsOffer, "deflate"))

	// Accept-Language Basic Filtering
	require.True(t, acceptsLanguageOfferBasic("en", "en-US", nil))
	require.False(t, acceptsLanguageOfferBasic("en-US", "en", nil))
	require.True(t, acceptsLanguageOfferBasic("EN", "en-us", nil))
	require.False(t, acceptsLanguageOfferBasic("en", "en_US", nil))
	require.Equal(t, "en-US", getOffer([]byte("fr-CA;q=0.8, en-US"), acceptsLanguageOfferBasic, "en-US", "fr-CA"))
	require.Empty(t, getOffer([]byte("xx"), acceptsLanguageOfferBasic, "en"))
	require.False(t, acceptsLanguageOfferBasic("en-*", "en-US", nil))
	require.True(t, acceptsLanguageOfferBasic("*", "en-US", nil))

	// Accept-Language Extended Filtering
	require.True(t, acceptsLanguageOfferExtended("en", "en-US", nil))
	require.True(t, acceptsLanguageOfferExtended("en", "en-Latn-US", nil))
	require.True(t, acceptsLanguageOfferExtended("en-*", "en-US", nil))
	require.True(t, acceptsLanguageOfferExtended("*-US", "en-US", nil))
	require.True(t, acceptsLanguageOfferExtended("en-US-*", "en-US", nil))
	require.True(t, acceptsLanguageOfferExtended("en-*", "en-US-CA", nil))
	require.False(t, acceptsLanguageOfferExtended("en-US", "en-GB", nil))
	require.False(t, acceptsLanguageOfferExtended("fr", "en-US", nil))
	require.False(t, acceptsLanguageOfferExtended("", "en-US", nil))
	require.False(t, acceptsLanguageOfferExtended("en", "", nil))
	require.True(t, acceptsLanguageOfferExtended("*", "en-US", nil))
	require.True(t, acceptsLanguageOfferExtended("en-*", "en", nil))
	require.Equal(t, "en-US", getOffer([]byte("fr-CA;q=0.8, en-*"), acceptsLanguageOfferExtended, "en-US", "fr-CA"))

	// Sliding and singleton barriers
	require.True(t, acceptsLanguageOfferExtended("de-*-DE", "de-DE", nil))
	require.True(t, acceptsLanguageOfferExtended("de-*-DE", "de-DE-x-goethe", nil))
	require.True(t, acceptsLanguageOfferExtended("de-*-DE", "de-Latn-DE-1996", nil))
	require.False(t, acceptsLanguageOfferExtended("de-*-DE", "de", nil))
	require.False(t, acceptsLanguageOfferExtended("de-*-DE", "de-x-DE", nil))
	require.True(t, acceptsLanguageOfferExtended("*-CH", "de-CH", nil))
	require.True(t, acceptsLanguageOfferExtended("*-CH", "de-Latn-CH", nil))
}
// Test_ReadContentReturnsBytes writes a known payload to a temporary file and
// checks that readContent copies it into the writer and reports its length.
func Test_ReadContentReturnsBytes(t *testing.T) {
	t.Parallel()

	want := []byte("fiber read content test")

	f, err := os.CreateTemp("", "fiber-read-content-*.txt")
	require.NoError(t, err)
	path := f.Name()
	t.Cleanup(func() {
		require.NoError(t, os.Remove(path))
	})

	_, err = f.Write(want)
	require.NoError(t, err)
	require.NoError(t, f.Close())

	var got bytes.Buffer
	n, err := readContent(&got, path)
	require.NoError(t, err)
	require.Equal(t, int64(len(want)), n)
	require.Equal(t, want, got.Bytes())
}
// wrappedListener embeds a net.Listener without adding any TLS accessor. The
// tests use it to check that getTLSConfig returns nil for such wrappers.
type wrappedListener struct {
	net.Listener
}

// tlsConfigMethodListener exposes its TLS configuration through TLSConfig().
type tlsConfigMethodListener struct {
	net.Listener
	cfg *tls.Config
}

// TLSConfig returns the stored TLS configuration.
func (l *tlsConfigMethodListener) TLSConfig() *tls.Config {
	cfg := l.cfg
	return cfg
}

// configMethodListener exposes its TLS configuration through Config().
type configMethodListener struct {
	net.Listener
	cfg *tls.Config
}

// Config returns the stored TLS configuration.
func (l *configMethodListener) Config() *tls.Config {
	cfg := l.cfg
	return cfg
}
func Test_GetTLSConfig(t *testing.T) {
t.Parallel()
t.Run("tls listener", func(t *testing.T) {
t.Parallel()
base := newLocalListener(t)
cfg := &tls.Config{MinVersion: tls.VersionTLS12}
tlsListener := tls.NewListener(base, cfg)
t.Cleanup(func() {
require.NoError(t, tlsListener.Close())
})
require.Same(t, cfg, getTLSConfig(tlsListener), "*tls.Listener should expose its TLS config")
})
t.Run("wrapped tls listener", func(t *testing.T) {
t.Parallel()
base := newLocalListener(t)
cfg := &tls.Config{MinVersion: tls.VersionTLS13}
tlsListener := tls.NewListener(base, cfg)
wrapped := &wrappedListener{Listener: tlsListener}
t.Cleanup(func() {
require.NoError(t, wrapped.Close())
})
require.Nil(t, getTLSConfig(wrapped), "wrapping without Config()-like methods should return nil")
})
t.Run("listener with tls config method", func(t *testing.T) {
t.Parallel()
base := newLocalListener(t)
cfg := &tls.Config{MinVersion: tls.VersionTLS13}
listener := &tlsConfigMethodListener{Listener: base, cfg: cfg}
t.Cleanup(func() {
require.NoError(t, listener.Close())
})
require.Same(t, cfg, getTLSConfig(listener), "TLSConfig() should be preferred for TLS discovery")
})
t.Run("listener with config method", func(t *testing.T) {
t.Parallel()
base := newLocalListener(t)
cfg := &tls.Config{MinVersion: tls.VersionTLS12}
listener := &configMethodListener{Listener: base, cfg: cfg}
t.Cleanup(func() {
require.NoError(t, listener.Close())
})
require.Same(t, cfg, getTLSConfig(listener), "Config() should be preferred for TLS discovery")
})
t.Run("non tls listener", func(t *testing.T) {
t.Parallel()
base := newLocalListener(t)
t.Cleanup(func() {
require.NoError(t, base.Close())
})
require.Nil(t, getTLSConfig(base), "plain listeners should not report TLS config")
})
}
// newLocalListener opens a TCP listener on an OS-assigned loopback port and
// fails the test at once if that is not possible. The caller owns the
// returned listener and is responsible for closing it.
func newLocalListener(t *testing.T) net.Listener {
	t.Helper()
	listener, listenErr := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, listenErr)
	return listener
}
// go test -v -run=^$ -bench=Benchmark_Utils_GetOffer -benchmem -count=4
//
// Benchmark_Utils_GetOffer measures getOffer negotiation for a range of
// Accept headers and offer lists. Several cases share a description; the
// testing package disambiguates duplicate sub-benchmark names with a "#NN"
// suffix.
func Benchmark_Utils_GetOffer(b *testing.B) {
	testCases := []struct {
		description string
		accept      string
		offers      []string
	}{
		{
			description: "simple",
			accept:      "application/json",
			offers:      []string{"application/json"},
		},
		{
			description: "6 offers",
			accept:      "text/plain",
			offers:      []string{"junk/a", "junk/b", "junk/c", "junk/d", "junk/e", "text/plain"},
		},
		{
			description: "1 parameter",
			accept:      "application/json; version=1",
			offers:      []string{"application/json;version=1"},
		},
		{
			description: "2 parameters",
			accept:      "application/json; version=1; foo=bar",
			offers:      []string{"application/json;version=1;foo=bar"},
		},
		{
			description: "3 parameters",
			accept:      "application/json; version=1; foo=bar; charset=utf-8",
			offers:      []string{"application/json;version=1;foo=bar;charset=utf-8"},
		},
		{
			description: "10 parameters",
			accept:      "text/plain;a=1;b=2;c=3;d=4;e=5;f=6;g=7;h=8;i=9;j=10",
			offers:      []string{"text/plain;a=1;b=2;c=3;d=4;e=5;f=6;g=7;h=8;i=9;j=10"},
		},
		{
			description: "6 offers w/params",
			accept:      "text/plain; format=flowed",
			offers: []string{
				"junk/a;a=b",
				"junk/b;b=c",
				"junk/c;c=d",
				"text/plain; format=justified",
				"text/plain; format=flat",
				"text/plain; format=flowed",
			},
		},
		{
			description: "mime extension",
			accept:      "utf-8, iso-8859-1;q=0.5",
			offers:      []string{"utf-8"},
		},
		{
			description: "mime extension",
			accept:      "utf-8, iso-8859-1;q=0.5",
			offers:      []string{"iso-8859-1"},
		},
		{
			description: "mime extension",
			accept:      "utf-8, iso-8859-1;q=0.5",
			offers:      []string{"iso-8859-1", "utf-8"},
		},
		{
			description: "mime extension",
			accept:      "gzip, deflate",
			offers:      []string{"deflate"},
		},
		{
			description: "web browser",
			accept:      "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
			offers:      []string{"text/html", "application/xml", "application/xml+xhtml"},
		},
	}

	for _, tc := range testCases {
		accept := []byte(tc.accept)
		b.Run(tc.description, func(b *testing.B) {
			// ReportAllocs applies only to the benchmark it is called on.
			// Sub-benchmarks do not inherit the parent's setting, so each
			// one must request allocation reporting itself.
			b.ReportAllocs()
			for b.Loop() {
				getOffer(accept, acceptsOfferType, tc.offers...)
			}
		})
	}
}
// Test_Utils_ParamsMatch checks paramsMatch, which compares parsed Accept
// parameters against an offer's raw ";key=value" parameter string. As the
// cases show, every accept parameter must be present in the offer, extra
// offer parameters are allowed, order is irrelevant, and keys and values
// compare case-insensitively.
func Test_Utils_ParamsMatch(t *testing.T) {
	// paramsMatch is a pure function, so this test can run alongside its
	// siblings, which are all parallel.
	t.Parallel()

	testCases := []struct {
		description string
		accept      headerParams
		offer       string
		match       bool
	}{
		{
			description: "empty accept and offer",
			accept:      nil,
			offer:       "",
			match:       true,
		},
		{
			description: "accept is empty, offer has params",
			accept:      make(headerParams),
			offer:       ";foo=bar",
			match:       true,
		},
		{
			description: "offer is empty, accept has params",
			accept:      headerParams{"foo": []byte("bar")},
			offer:       "",
			match:       false,
		},
		{
			description: "accept has extra parameters",
			accept:      headerParams{"foo": []byte("bar"), "a": []byte("1")},
			offer:       ";foo=bar",
			match:       false,
		},
		{
			description: "matches regardless of order",
			accept:      headerParams{"b": []byte("2"), "a": []byte("1")},
			offer:       ";b=2;a=1",
			match:       true,
		},
		{
			description: "case-insensitive",
			accept:      headerParams{"ParaM": []byte("FoO")},
			offer:       ";pAram=foO",
			match:       true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.description, func(t *testing.T) {
			require.Equal(t, tc.match, paramsMatch(tc.accept, tc.offer), tc.description)
		})
	}
}
// Benchmark_Utils_ParamsMatch measures matching an offer parameter string
// against parsed parameters whose keys differ in case and order, then checks
// that the final call reported a match.
func Benchmark_Utils_ParamsMatch(b *testing.B) {
	params := headerParams{
		"appLe": []byte("orange"),
		"param": []byte("foo"),
	}
	const offer = `;param=foo; apple=orange`

	matched := false
	b.ReportAllocs()
	for b.Loop() {
		matched = paramsMatch(params, offer)
	}
	require.True(b, matched)
}
// Test_Utils_AcceptsOfferType checks acceptsOfferType against media-type
// specs with and without parameters. It covers exact and mismatched types,
// extra parameters on either side, optional whitespace, and quoted
// parameter values.
func Test_Utils_AcceptsOfferType(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		description string
		spec        string
		specParams  headerParams
		offerType   string
		accepts     bool
	}{
		{
			description: "no params, matching",
			spec:        "application/json",
			offerType:   "application/json",
			accepts:     true,
		},
		{
			description: "no params, mismatch",
			spec:        "application/json",
			offerType:   "application/xml",
			accepts:     false,
		},
		{
			description: "mismatch with subtype prefix",
			spec:        "application/json",
			offerType:   "application/json+xml",
			accepts:     false,
		},
		{
			description: "params match",
			spec:        "application/json",
			specParams:  headerParams{"format": []byte("foo"), "version": []byte("1")},
			offerType:   "application/json;version=1;format=foo;q=0.1",
			accepts:     true,
		},
		{
			description: "spec has extra params",
			spec:        "text/html",
			specParams:  headerParams{"charset": []byte("utf-8")},
			offerType:   "text/html",
			accepts:     false,
		},
		{
			description: "offer has extra params",
			spec:        "text/html",
			offerType:   "text/html;charset=utf-8",
			accepts:     true,
		},
		{
			description: "ignores optional whitespace",
			spec:        "application/json",
			specParams:  headerParams{"format": []byte("foo"), "version": []byte("1")},
			offerType:   "application/json; version=1 ; format=foo ",
			accepts:     true,
		},
		{
			// Previously shared the "ignores optional whitespace" label,
			// which made failures of the two cases indistinguishable.
			description: "handles quoted parameter values",
			spec:        "application/json",
			specParams:  headerParams{"format": []byte("foo bar"), "version": []byte("1")},
			offerType:   `application/json;version="1";format="foo bar"`,
			accepts:     true,
		},
	}
	for _, tc := range testCases {
		accepts := acceptsOfferType(tc.spec, tc.offerType, tc.specParams)
		require.Equal(t, tc.accepts, accepts, tc.description)
	}
}
// Test_Utils_GetSplicedStrList checks that getSplicedStrList splits a
// comma-separated header value and trims whitespace around each element.
// Empty elements produced by stray commas are kept, and an empty header
// yields nil.
func Test_Utils_GetSplicedStrList(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		description  string
		headerValue  string
		expectedList []string
	}{
		{
			description:  "normal case",
			headerValue:  "gzip, deflate,br",
			expectedList: []string{"gzip", "deflate", "br"},
		},
		{
			description:  "leading whitespace is trimmed",
			headerValue:  " gzip,deflate, br, zip",
			expectedList: []string{"gzip", "deflate", "br", "zip"},
		},
		{
			description:  "comma with trailing spaces around values",
			headerValue:  "gzip , br",
			expectedList: []string{"gzip", "br"},
		},
		{
			description:  "comma with tabbed whitespace",
			headerValue:  "gzip\t,br",
			expectedList: []string{"gzip", "br"},
		},
		{
			description:  "headerValue is empty",
			headerValue:  "",
			expectedList: nil,
		},
		{
			description:  "has a comma without element",
			headerValue:  "gzip,",
			expectedList: []string{"gzip", ""},
		},
		{
			description:  "has a space between words",
			headerValue:  " foo bar, hello world",
			expectedList: []string{"foo bar", "hello world"},
		},
		{
			description:  "single comma",
			headerValue:  ",",
			expectedList: []string{"", ""},
		},
		{
			description:  "multiple comma",
			headerValue:  ",,",
			expectedList: []string{"", "", ""},
		},
		{
			description:  "comma with space",
			headerValue:  ", ,",
			expectedList: []string{"", "", ""},
		},
	}
	for _, tc := range testCases {
		// Since Go 1.22 each iteration has its own tc, so the parallel
		// subtest can capture it directly without a shadow copy.
		t.Run(tc.description, func(t *testing.T) {
			t.Parallel()
			dst := make([]string, 10)
			result := getSplicedStrList(tc.headerValue, dst)
			require.Equal(t, tc.expectedList, result)
		})
	}
}
// Benchmark_Utils_GetSplicedStrList measures splitting a five-element
// comma-separated header into a reused five-slot destination slice.
func Benchmark_Utils_GetSplicedStrList(b *testing.B) {
	const header = `deflate, gzip,br,brotli,zstd`
	buf := make([]string, 5)
	got := buf

	b.ReportAllocs()
	for b.Loop() {
		got = getSplicedStrList(header, buf)
	}

	want := []string{"deflate", "gzip", "br", "brotli", "zstd"}
	require.Equal(b, want, got)
}
// Test_Utils_SortAcceptedTypes checks the ordering produced by
// sortAcceptedTypes. The expectation is: higher quality first, then higher
// specificity, then original order. The application/json pair also shows an
// entry with parameters ranking ahead of an otherwise-equal entry without
// them.
func Test_Utils_SortAcceptedTypes(t *testing.T) {
	t.Parallel()

	got := []acceptedType{
		{spec: "text/html", quality: 1, specificity: 3, order: 0},
		{spec: "text/*", quality: 0.5, specificity: 2, order: 1},
		{spec: "*/*", quality: 0.1, specificity: 1, order: 2},
		{spec: "application/xml", quality: 1, specificity: 3, order: 4},
		{spec: "application/pdf", quality: 1, specificity: 3, order: 5},
		{spec: "image/png", quality: 1, specificity: 3, order: 6},
		{spec: "image/jpeg", quality: 1, specificity: 3, order: 7},
		{spec: "image/*", quality: 1, specificity: 2, order: 8},
		{spec: "image/gif", quality: 1, specificity: 3, order: 9},
		{spec: "text/plain", quality: 1, specificity: 3, order: 10},
		{spec: "application/json", quality: 0.999, specificity: 3, params: headerParams{"a": []byte("1")}, order: 11},
		{spec: "application/json", quality: 0.999, specificity: 3, order: 3},
	}
	want := []acceptedType{
		{spec: "text/html", quality: 1, specificity: 3, order: 0},
		{spec: "application/xml", quality: 1, specificity: 3, order: 4},
		{spec: "application/pdf", quality: 1, specificity: 3, order: 5},
		{spec: "image/png", quality: 1, specificity: 3, order: 6},
		{spec: "image/jpeg", quality: 1, specificity: 3, order: 7},
		{spec: "image/gif", quality: 1, specificity: 3, order: 9},
		{spec: "text/plain", quality: 1, specificity: 3, order: 10},
		{spec: "image/*", quality: 1, specificity: 2, order: 8},
		{spec: "application/json", quality: 0.999, specificity: 3, params: headerParams{"a": []byte("1")}, order: 11},
		{spec: "application/json", quality: 0.999, specificity: 3, order: 3},
		{spec: "text/*", quality: 0.5, specificity: 2, order: 1},
		{spec: "*/*", quality: 0.1, specificity: 1, order: 2},
	}

	sortAcceptedTypes(got)
	require.Equal(t, want, got)
}
// go test -v -run=^$ -bench=Benchmark_Utils_SortAcceptedTypes_Sorted -benchmem -count=4
//
// Benchmark_Utils_SortAcceptedTypes_Sorted measures sortAcceptedTypes on
// input that is already in its final order. Each iteration restores the
// working slice from a fixture before sorting.
func Benchmark_Utils_SortAcceptedTypes_Sorted(b *testing.B) {
	fixture := [3]acceptedType{
		{spec: "text/html", quality: 1, specificity: 1, order: 0},
		{spec: "text/*", quality: 0.5, specificity: 1, order: 1},
		{spec: "*/*", quality: 0.1, specificity: 1, order: 2},
	}
	types := make([]acceptedType, len(fixture))

	b.ReportAllocs()
	for b.Loop() {
		copy(types, fixture[:])
		sortAcceptedTypes(types)
	}

	require.Equal(b, "text/html", types[0].spec)
	require.Equal(b, "text/*", types[1].spec)
	require.Equal(b, "*/*", types[2].spec)
}
// go test -v -run=^$ -bench=Benchmark_Utils_SortAcceptedTypes_Unsorted -benchmem -count=4
//
// Benchmark_Utils_SortAcceptedTypes_Unsorted measures sortAcceptedTypes on
// an eleven-entry list in header order. Each iteration restores the working
// slice from a fixture before sorting.
func Benchmark_Utils_SortAcceptedTypes_Unsorted(b *testing.B) {
	fixture := [11]acceptedType{
		{spec: "text/html", quality: 1, specificity: 3, order: 0},
		{spec: "text/*", quality: 0.5, specificity: 2, order: 1},
		{spec: "*/*", quality: 0.1, specificity: 1, order: 2},
		{spec: "application/json", quality: 0.999, specificity: 3, order: 3},
		{spec: "application/xml", quality: 1, specificity: 3, order: 4},
		{spec: "application/pdf", quality: 1, specificity: 3, order: 5},
		{spec: "image/png", quality: 1, specificity: 3, order: 6},
		{spec: "image/jpeg", quality: 1, specificity: 3, order: 7},
		{spec: "image/*", quality: 1, specificity: 2, order: 8},
		{spec: "image/gif", quality: 1, specificity: 3, order: 9},
		{spec: "text/plain", quality: 1, specificity: 3, order: 10},
	}
	types := make([]acceptedType, len(fixture))

	b.ReportAllocs()
	for b.Loop() {
		copy(types, fixture[:])
		sortAcceptedTypes(types)
	}

	want := []acceptedType{
		{spec: "text/html", quality: 1, specificity: 3, order: 0},
		{spec: "application/xml", quality: 1, specificity: 3, order: 4},
		{spec: "application/pdf", quality: 1, specificity: 3, order: 5},
		{spec: "image/png", quality: 1, specificity: 3, order: 6},
		{spec: "image/jpeg", quality: 1, specificity: 3, order: 7},
		{spec: "image/gif", quality: 1, specificity: 3, order: 9},
		{spec: "text/plain", quality: 1, specificity: 3, order: 10},
		{spec: "image/*", quality: 1, specificity: 2, order: 8},
		{spec: "application/json", quality: 0.999, specificity: 3, order: 3},
		{spec: "text/*", quality: 0.5, specificity: 2, order: 1},
		{spec: "*/*", quality: 0.1, specificity: 1, order: 2},
	}
	require.Equal(b, want, types)
}
// Test_Utils_UniqueRouteStack checks that uniqueRouteStack collapses
// repeated routes down to one entry each, in first-seen order.
func Test_Utils_UniqueRouteStack(t *testing.T) {
	t.Parallel()
	first, second, third := &Route{}, &Route{}, &Route{}
	trio := []*Route{first, second, third}

	// Each route three times in a row, followed by the whole trio once more.
	stack := make([]*Route, 0, 12)
	for _, r := range trio {
		stack = append(stack, r, r, r)
	}
	stack = append(stack, trio...)

	require.Equal(t, []*Route{first, second, third}, uniqueRouteStack(stack))
}
func Test_Utils_getGroupPath(t *testing.T) {
t.Parallel()
res := getGroupPath("/v1", "/")
require.Equal(t, "/v1/", res)
res = getGroupPath("/v1/", "/")
require.Equal(t, "/v1/", res)
res = getGroupPath("/", "/")
require.Equal(t, "/", res)
res = getGroupPath("/v1/api/", "/")
require.Equal(t, "/v1/api/", res)
res = getGroupPath("/v1/api", "group")
require.Equal(t, "/v1/api/group", res)
res = getGroupPath("/v1/api", "")
require.Equal(t, "/v1/api", res)
}
// go test -v -run=^$ -bench=Benchmark_Utils_ -benchmem -count=3
//
// Benchmark_Utils_getGroupPath measures four getGroupPath calls per
// iteration and checks the result of the last one.
func Benchmark_Utils_getGroupPath(b *testing.B) {
	b.ReportAllocs()
	last := ""
	for b.Loop() {
		_ = getGroupPath("/v1/long/path/john/doe", "/why/this/name/is/so/awesome")
		_ = getGroupPath("/v1", "/")
		_ = getGroupPath("/v1", "/api")
		last = getGroupPath("/v1", "/api/register/:project")
	}
	require.Equal(b, "/v1/api/register/:project", last)
}
func Benchmark_Utils_Unescape(b *testing.B) {
unescaped := ""
dst := make([]byte, 0)
b.ReportAllocs()
for b.Loop() {
source := "/cr%C3%A9er"
pathBytes := utils.UnsafeBytes(source)
pathBytes = fasthttp.AppendUnquotedArg(dst[:0], pathBytes)
unescaped = utils.UnsafeString(pathBytes)
}
require.Equal(b, "/créer", unescaped)
}
// Test_Utils_Parse_Address checks that parseAddr splits addresses into host
// and port. It covers bracketed IPv6 (with and without zone), bare IPv6,
// Unix socket paths, hostnames, port-only input, surrounding spaces, and the
// empty string.
func Test_Utils_Parse_Address(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		addr, host, port string
	}{
		{"[::1]:3000", "[::1]", "3000"},
		{"127.0.0.1:3000", "127.0.0.1", "3000"},
		{"[::1]", "[::1]", ""},
		{"2001:db8::1", "2001:db8::1", ""},
		{"/path/to/unix/socket", "/path/to/unix/socket", ""},
		{"127.0.0.1", "127.0.0.1", ""},
		{"localhost:8080", "localhost", "8080"},
		{"example.com", "example.com", ""},
		{"[fe80::1%lo0]:1234", "[fe80::1%lo0]", "1234"},
		{"[fe80::1%lo0]", "[fe80::1%lo0]", ""},
		{":9090", "", "9090"},
		{" 127.0.0.1:8080 ", "127.0.0.1", "8080"},
		{"", "", ""},
	}
	for _, tc := range testCases {
		gotHost, gotPort := parseAddr(tc.addr)
		require.Equal(t, tc.host, gotHost, "addr host: %q", tc.addr)
		require.Equal(t, tc.port, gotPort, "addr port: %q", tc.addr)
	}
}
// Test_Utils_TestConn_Deadline checks that every deadline setter on testConn
// accepts the zero time without error.
func Test_Utils_TestConn_Deadline(t *testing.T) {
	t.Parallel()
	var zero time.Time
	c := &testConn{}
	setters := []func(time.Time) error{c.SetDeadline, c.SetReadDeadline, c.SetWriteDeadline}
	for _, set := range setters {
		require.NoError(t, set(zero))
	}
}
// Test_Utils_TestConn_ReadWrite checks testConn's two directions. Bytes
// placed in conn.r are returned by conn.Read, and bytes passed to conn.Write
// land in conn.w.
func Test_Utils_TestConn_ReadWrite(t *testing.T) {
	t.Parallel()
	conn := &testConn{}

	// Inbound: seed the read side, then consume it through the conn.
	_, err := conn.r.Write([]byte("Request"))
	require.NoError(t, err)
	inbound := make([]byte, len("Request"))
	_, err = conn.Read(inbound)
	require.NoError(t, err)
	require.Equal(t, []byte("Request"), inbound)

	// Outbound: write through the conn, then drain the write side.
	_, err = conn.Write([]byte("Response"))
	require.NoError(t, err)
	outbound := make([]byte, len("Response"))
	_, err = conn.w.Read(outbound)
	require.NoError(t, err)
	require.Equal(t, []byte("Response"), outbound)
}
// Test_Utils_TestConn_Closed_Write checks that writes after Close fail with
// errTestConnClosed. Data written before closing must remain readable from
// the write side.
func Test_Utils_TestConn_Closed_Write(t *testing.T) {
	t.Parallel()
	conn := &testConn{}

	_, err := conn.Write([]byte("Response 1\n"))
	require.NoError(t, err)

	conn.Close() //nolint:errcheck // It is fine to ignore the error here
	_, err = conn.Write([]byte("Response 2\n"))
	require.ErrorIs(t, err, errTestConnClosed)

	// Only the pre-close write should be buffered.
	written := make([]byte, len("Response 1\n"))
	_, err = conn.w.Read(written)
	require.NoError(t, err)
	require.Equal(t, []byte("Response 1\n"), written)
}
// Test_Utils_IsNoCache checks detection of the no-cache directive in
// Cache-Control values. The cases cover position, case-insensitivity,
// near-miss tokens, and the field-name argument form.
func Test_Utils_IsNoCache(t *testing.T) {
	t.Parallel()
	cases := []struct {
		in   string
		want bool
	}{
		{in: "public", want: false},
		{in: "no-cache", want: true},
		{in: "public, no-cache, max-age=30", want: true},
		{in: "public,no-cache", want: true},
		{in: "public,no-cacheX", want: false},
		{in: "no-cache, public", want: true},
		{in: "Xno-cache, public", want: false},
		{in: "max-age=30, no-cache,public", want: true},
		{in: "NO-CACHE", want: true},
		{in: "public, NO-CACHE", want: true},
		// RFC 9111 §5.2.2.4: no-cache with field-name argument
		{in: "no-cache=\"Set-Cookie\"", want: true},
		{in: "public, no-cache=\"Set-Cookie, Set-Cookie2\"", want: true},
		{in: "no-cache=Set-Cookie", want: true},
		// Edge cases with spaces
		{in: "no-cache ,public", want: true},
		{in: "public, no-cache =field", want: true},
	}
	for _, tc := range cases {
		got := isNoCache(tc.in)
		require.Equal(t, tc.want, got, "want %t, got isNoCache(%s)=%t", tc.want, tc.in, got)
	}
}
// go test -v -run=^$ -bench=Benchmark_Utils_IsNoCache -benchmem -count=4
//
// Benchmark_Utils_IsNoCache measures six isNoCache calls per iteration and
// checks that the last input is detected as no-cache.
func Benchmark_Utils_IsNoCache(b *testing.B) {
	found := false
	b.ReportAllocs()
	for b.Loop() {
		_ = isNoCache("public")
		_ = isNoCache("no-cache")
		_ = isNoCache("public, no-cache, max-age=30")
		_ = isNoCache("public,no-cache")
		_ = isNoCache("no-cache, public")
		found = isNoCache("max-age=30, no-cache,public")
	}
	require.True(b, found)
}
// go test -run Test_HeaderContainsValue
//
// Test_HeaderContainsValue checks whole-token matching of a value inside a
// comma-separated header. It covers optional whitespace, near-miss tokens,
// and empty inputs.
func Test_HeaderContainsValue(t *testing.T) {
	t.Parallel()
	cases := []struct {
		hdr  string
		val  string
		want bool
	}{
		// Exact match
		{hdr: "gzip", val: "gzip", want: true},
		{hdr: "gzip", val: "deflate", want: false},
		// Prefix match (value at start with comma)
		{hdr: "gzip, deflate", val: "gzip", want: true},
		{hdr: "gzip,deflate", val: "gzip", want: true},
		// Suffix match (value at end)
		{hdr: "deflate, gzip", val: "gzip", want: true},
		{hdr: "deflate,gzip", val: "gzip", want: true}, // No space - OWS is optional per RFC 9110
		{hdr: "br, gzip", val: "gzip", want: true},
		// Middle match (value in middle)
		{hdr: "deflate, gzip, br", val: "gzip", want: true},
		{hdr: "deflate,gzip,br", val: "gzip", want: true}, // No spaces - OWS is optional per RFC 9110
		// No match - similar but not equal
		{hdr: "gzip2", val: "gzip", want: false},
		{hdr: "2gzip", val: "gzip", want: false},
		{hdr: "gzip2, deflate", val: "gzip", want: false},
		// Whitespace handling (OWS per RFC 9110)
		{hdr: " gzip , deflate ", val: "gzip", want: true},
		{hdr: "deflate, gzip ", val: "gzip", want: true},
		// Empty cases
		{hdr: "", val: "gzip", want: false},
		{hdr: "gzip", val: "", want: false},
		{hdr: "", val: "", want: false}, // Both empty - should return false
	}
	for _, c := range cases {
		got := headerContainsValue(c.hdr, c.val)
		require.Equal(t, c.want, got,
			"headerContainsValue(%q, %q) = %v, want %v",
			c.hdr, c.val, got, c.want)
	}
}
// go test -v -run=^$ -bench=Benchmark_HeaderContainsValue -benchmem -count=4
//
// Benchmark_HeaderContainsValue measures four headerContainsValue lookups per
// iteration (exact, middle, and suffix positions) and checks the last result.
func Benchmark_HeaderContainsValue(b *testing.B) {
	found := false
	b.ReportAllocs()
	for b.Loop() {
		_ = headerContainsValue("gzip", "gzip")
		_ = headerContainsValue("gzip, deflate, br", "deflate")
		_ = headerContainsValue("deflate, gzip", "gzip")
		found = headerContainsValue("deflate, gzip, br", "gzip")
	}
	require.True(b, found)
}
// testGenericParseTypeIntCase is one signed-integer parsing scenario.
type testGenericParseTypeIntCase struct {
	value int64 // number rendered with strconv.FormatInt and parsed back
	bits  int   // smallest integer width, in bits, that can hold value
}
// go test -run Test_GenericParseTypeInts
//
// Test_GenericParseTypeInts runs the same signed-integer cases against every
// signed width. Cases wider than the target type must fail with a range
// error.
func Test_GenericParseTypeInts(t *testing.T) {
	t.Parallel()
	cases := []testGenericParseTypeIntCase{
		{value: 0, bits: 8},
		{value: 1, bits: 8},
		{value: 2, bits: 8},
		{value: 3, bits: 8},
		{value: 4, bits: 8},
		{value: -1, bits: 8},
		{value: math.MaxInt8, bits: 8},
		{value: math.MinInt8, bits: 8},
		{value: math.MaxInt16, bits: 16},
		{value: math.MinInt16, bits: 16},
		{value: math.MaxInt32, bits: 32},
		{value: math.MinInt32, bits: 32},
		{value: math.MaxInt64, bits: 64},
		{value: math.MinInt64, bits: 64},
	}

	testGenericTypeInt[int8](t, "test_genericParseTypeInt8s", cases)
	testGenericTypeInt[int16](t, "test_genericParseTypeInt16s", cases)
	testGenericTypeInt[int32](t, "test_genericParseTypeInt32s", cases)
	testGenericTypeInt[int64](t, "test_genericParseTypeInt64s", cases)
	testGenericTypeInt[int](t, "test_genericParseTypeInts", cases)
}
// testGenericTypeInt registers a parallel subtest that formats each case with
// strconv.FormatInt and parses it back with genericParseType[V]. Cases whose
// bit width exceeds V's size must fail with strconv.ErrRange; the rest must
// decode to V(value). The subtest ends by checking invalid input through
// testGenericParseError.
func testGenericTypeInt[V GenericTypeInteger](t *testing.T, name string, cases []testGenericParseTypeIntCase) {
	t.Helper()
	t.Run(name, func(t *testing.T) {
		t.Parallel()
		width := int(unsafe.Sizeof(V(0))) * 8
		for _, tc := range cases {
			got, err := genericParseType[V](strconv.FormatInt(tc.value, 10))
			if tc.bits > width {
				require.ErrorIs(t, err, strconv.ErrRange)
				continue
			}
			require.NoError(t, err)
			require.Equal(t, V(tc.value), got)
		}
		testGenericParseError[V](t)
	})
}
// testGenericParseTypeUintCase is one unsigned-integer parsing scenario.
type testGenericParseTypeUintCase struct {
	value uint64 // number rendered with strconv.FormatUint and parsed back
	bits  int    // smallest integer width, in bits, that can hold value
}
// go test -run Test_GenericParseTypeUints
//
// Test_GenericParseTypeUints runs the same unsigned-integer cases against
// every unsigned width. Cases wider than the target type must fail with a
// range error.
func Test_GenericParseTypeUints(t *testing.T) {
	t.Parallel()
	cases := []testGenericParseTypeUintCase{
		{value: 0, bits: 8},
		{value: 1, bits: 8},
		{value: 2, bits: 8},
		{value: 3, bits: 8},
		{value: 4, bits: 8},
		{value: math.MaxUint8, bits: 8},
		{value: math.MaxUint16, bits: 16},
		{value: math.MaxUint32, bits: 32},
		{value: math.MaxUint64, bits: 64},
	}

	testGenericTypeUint[uint8](t, "test_genericParseTypeUint8s", cases)
	testGenericTypeUint[uint16](t, "test_genericParseTypeUint16s", cases)
	testGenericTypeUint[uint32](t, "test_genericParseTypeUint32s", cases)
	testGenericTypeUint[uint64](t, "test_genericParseTypeUint64s", cases)
	testGenericTypeUint[uint](t, "test_genericParseTypeUints", cases)
}
// testGenericTypeUint registers a parallel subtest that formats each case with
// strconv.FormatUint and parses it back with genericParseType[V]. Cases whose
// bit width exceeds V's size must fail with strconv.ErrRange; the rest must
// decode to V(value). The subtest ends by checking invalid input through
// testGenericParseError.
func testGenericTypeUint[V GenericTypeInteger](t *testing.T, name string, cases []testGenericParseTypeUintCase) {
	t.Helper()
	t.Run(name, func(t *testing.T) {
		t.Parallel()
		width := int(unsafe.Sizeof(V(0))) * 8
		for _, tc := range cases {
			got, err := genericParseType[V](strconv.FormatUint(tc.value, 10))
			if tc.bits > width {
				require.ErrorIs(t, err, strconv.ErrRange)
				continue
			}
			require.NoError(t, err)
			require.Equal(t, V(tc.value), got)
		}
		testGenericParseError[V](t)
	})
}
// go test -run Test_GenericParseTypeFloats
//
// Test_GenericParseTypeFloats parses decimal strings as float32 and float64
// and compares the results within epsilon. It also checks invalid input for
// both types.
func Test_GenericParseTypeFloats(t *testing.T) {
	t.Parallel()
	floats := []struct {
		str   string
		value float64
	}{
		{str: "3.1415", value: 3.1415},
		{str: "1.234", value: 1.234},
		{str: "2", value: 2},
		{str: "3", value: 3},
	}

	t.Run("test_genericParseTypeFloat32s", func(t *testing.T) {
		t.Parallel()
		for _, f := range floats {
			got, err := genericParseType[float32](f.str)
			require.NoError(t, err)
			require.InEpsilon(t, float32(f.value), got, epsilon)
		}
		testGenericParseError[float32](t)
	})

	t.Run("test_genericParseTypeFloat64s", func(t *testing.T) {
		t.Parallel()
		for _, f := range floats {
			got, err := genericParseType[float64](f.str)
			require.NoError(t, err)
			require.InEpsilon(t, f.value, got, epsilon)
		}
		testGenericParseError[float64](t)
	})
}
// go test -run Test_GenericParseTypeBytes
//
// Test_GenericParseTypeBytes checks that genericParseType[[]byte] returns the
// input's bytes. Empty input must yield nil together with
// errParsedEmptyBytes.
func Test_GenericParseTypeBytes(t *testing.T) {
	t.Parallel()
	type bytesCase struct {
		str   string
		err   error
		value []byte
	}
	cases := []bytesCase{
		{str: "alex", value: []byte("alex")},
		{str: "32.23", value: []byte("32.23")},
		{str: "john", value: []byte("john")},
		{str: "", value: nil, err: errParsedEmptyBytes},
	}

	t.Run("test_genericParseTypeBytes", func(t *testing.T) {
		t.Parallel()
		for _, tc := range cases {
			got, err := genericParseType[[]byte](tc.str)
			if tc.err != nil {
				require.ErrorIs(t, err, tc.err)
			} else {
				require.NoError(t, err)
			}
			require.Equal(t, tc.value, got)
		}
	})
}
// go test -run Test_GenericParseTypeString
//
// Test_GenericParseTypeString checks that genericParseType[string] returns
// its input unchanged. Each value runs in its own parallel subtest.
func Test_GenericParseTypeString(t *testing.T) {
	t.Parallel()
	for _, want := range []string{"john", "doe", "hello", "fiber"} {
		t.Run("test_genericParseTypeString", func(t *testing.T) {
			t.Parallel()
			got, err := genericParseType[string](want)
			require.NoError(t, err)
			require.Equal(t, want, got)
		})
	}
}
// go test -run Test_GenericParseTypeBoolean
//
// Test_GenericParseTypeBoolean checks parsing of lower- and title-case
// boolean literals and rejection of invalid input.
func Test_GenericParseTypeBoolean(t *testing.T) {
	t.Parallel()
	bools := []struct {
		str   string
		value bool
	}{
		{str: "True", value: true},
		{str: "False", value: false},
		{str: "true", value: true},
		{str: "false", value: false},
	}

	t.Run("test_genericParseTypeBoolean", func(t *testing.T) {
		t.Parallel()
		for _, b := range bools {
			got, err := genericParseType[bool](b.str)
			require.NoError(t, err)
			require.Equal(t, b.value, got)
		}
		testGenericParseError[bool](t)
	})
}
// testGenericParseError checks that genericParseType[V] rejects a clearly
// invalid string and returns V's zero value along with the error.
func testGenericParseError[V GenericType](t *testing.T) {
	t.Helper()
	got, err := genericParseType[V]("invalid-string")
	require.Error(t, err)

	var zero V
	require.Equal(t, zero, got)
}
// go test -v -run=^$ -bench=Benchmark_GenericParseTypeInts -benchmem -count=4
//
// Benchmark_GenericParseTypeInts benchmarks each signed-integer case against
// every signed width via benchGenericParseTypeInt.
func Benchmark_GenericParseTypeInts(b *testing.B) {
	b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)")

	cases := []testGenericParseTypeIntCase{
		{value: 0, bits: 8},
		{value: 1, bits: 8},
		{value: 2, bits: 8},
		{value: 3, bits: 8},
		{value: 4, bits: 8},
		{value: -1, bits: 8},
		{value: math.MaxInt8, bits: 8},
		{value: math.MinInt8, bits: 8},
		{value: math.MaxInt16, bits: 16},
		{value: math.MinInt16, bits: 16},
		{value: math.MaxInt32, bits: 32},
		{value: math.MinInt32, bits: 32},
		{value: math.MaxInt64, bits: 64},
		{value: math.MinInt64, bits: 64},
	}
	for i := range cases {
		benchGenericParseTypeInt[int8](b, "bench_genericParseTypeInt8s", cases[i])
		benchGenericParseTypeInt[int16](b, "bench_genericParseTypeInt16s", cases[i])
		benchGenericParseTypeInt[int32](b, "bench_genericParseTypeInt32s", cases[i])
		benchGenericParseTypeInt[int64](b, "bench_genericParseTypeInt64s", cases[i])
		benchGenericParseTypeInt[int](b, "bench_genericParseTypeInts", cases[i])
	}
}
// benchGenericParseTypeInt registers a sub-benchmark that measures
// genericParseType[V] on the decimal form of test.value, then validates one
// parse. The result must be V(test.value) when test.bits fits in V, and
// strconv.ErrRange otherwise.
func benchGenericParseTypeInt[V GenericTypeInteger](b *testing.B, name string, test testGenericParseTypeIntCase) {
	b.Helper()
	b.Run(name, func(sb *testing.B) {
		// Formatting is loop-invariant; keep it out of the timed region.
		input := strconv.FormatInt(test.value, 10)

		// Timer, allocation reporting, and the parallel loop must all be
		// driven by the sub-benchmark. Driving them through the parent b
		// ran the loop under the parent's b.N and left this sub-benchmark's
		// measurements meaningless.
		sb.ReportAllocs()
		sb.ResetTimer()
		sb.RunParallel(func(pb *testing.PB) {
			// Keep results goroutine-local. RunParallel executes this body
			// on several goroutines, so writing to captured variables races.
			var sink V
			var sinkErr error
			for pb.Next() {
				sink, sinkErr = genericParseType[V](input)
			}
			_, _ = sink, sinkErr
		})

		// Validate outside the parallel section, on the benchmark goroutine.
		v, err := genericParseType[V](input)
		if test.bits <= int(unsafe.Sizeof(V(0)))*8 {
			require.NoError(sb, err)
			require.Equal(sb, V(test.value), v)
		} else {
			require.ErrorIs(sb, err, strconv.ErrRange)
		}
	})
}
// go test -v -run=^$ -bench=Benchmark_GenericParseTypeUints -benchmem -count=4
func Benchmark_GenericParseTypeUints(b *testing.B) {
b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)")
uints := []struct {
value uint64
bits int
}{
{
value: 0,
bits: 8,
},
{
value: 1,
bits: 8,
},
{
value: 2,
bits: 8,
},
{
value: 3,
bits: 8,
},
{
value: 4,
bits: 8,
},
{
value: math.MaxUint8,
bits: 8,
},
{
value: math.MaxUint16,
bits: 16,
},
{
value: math.MaxUint16,
bits: 16,
},
{
value: math.MaxUint32,
bits: 32,
},
{
value: math.MaxUint64,
bits: 64,
},
}
for _, test := range uints {
benchGenericParseTypeUInt[uint8](b, "benchmark_genericParseTypeUint8s", test)
benchGenericParseTypeUInt[uint16](b, "benchmark_genericParseTypeUint16s", test)
benchGenericParseTypeUInt[uint32](b, "benchmark_genericParseTypeUint32s", test)
benchGenericParseTypeUInt[uint64](b, "benchmark_genericParseTypeUint64s", test)
benchGenericParseTypeUInt[uint](b, "benchmark_genericParseTypeUints", test)
}
}
// benchGenericParseTypeUInt registers a sub-benchmark that measures
// genericParseType[V] on the decimal form of test.value, then validates one
// parse. The result must be V(test.value) when test.bits fits in V, and
// strconv.ErrRange otherwise.
func benchGenericParseTypeUInt[V GenericTypeInteger](b *testing.B, name string, test testGenericParseTypeUintCase) {
	b.Helper()
	b.Run(name, func(sb *testing.B) {
		// Formatting is loop-invariant; keep it out of the timed region.
		input := strconv.FormatUint(test.value, 10)

		// Timer, allocation reporting, and the parallel loop must all be
		// driven by the sub-benchmark. Driving them through the parent b
		// ran the loop under the parent's b.N and left this sub-benchmark's
		// measurements meaningless.
		sb.ReportAllocs()
		sb.ResetTimer()
		sb.RunParallel(func(pb *testing.PB) {
			// Keep results goroutine-local. RunParallel executes this body
			// on several goroutines, so writing to captured variables races.
			var sink V
			var sinkErr error
			for pb.Next() {
				sink, sinkErr = genericParseType[V](input)
			}
			_, _ = sink, sinkErr
		})

		// Validate outside the parallel section, on the benchmark goroutine.
		v, err := genericParseType[V](input)
		if test.bits <= int(unsafe.Sizeof(V(0)))*8 {
			require.NoError(sb, err)
			require.Equal(sb, V(test.value), v)
		} else {
			require.ErrorIs(sb, err, strconv.ErrRange)
		}
	})
}
// go test -v -run=^$ -bench=Benchmark_GenericParseTypeFloats -benchmem -count=4
//
// Benchmark_GenericParseTypeFloats measures genericParseType for float32 and
// float64 on a few decimal strings, then validates one parse per case within
// epsilon.
//
// NOTE(review): the sub-benchmarks previously drove the parent b (and raced
// on shared results), which may explain the unstable sub-ns numbers cited in
// the Skip message. Consider re-enabling after confirming.
func Benchmark_GenericParseTypeFloats(b *testing.B) {
	b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)")
	floats := []struct {
		str   string
		value float64
	}{
		{str: "3.1415", value: 3.1415},
		{str: "1.234", value: 1.234},
		{str: "2", value: 2},
		{str: "3", value: 3},
	}

	for _, test := range floats {
		b.Run("benchmark_genericParseTypeFloat32s", func(sb *testing.B) {
			sb.ReportAllocs()
			sb.ResetTimer()
			sb.RunParallel(func(pb *testing.PB) {
				// Goroutine-local results avoid a data race between workers.
				var sink float32
				var sinkErr error
				for pb.Next() {
					sink, sinkErr = genericParseType[float32](test.str)
				}
				_, _ = sink, sinkErr
			})

			v, err := genericParseType[float32](test.str)
			require.NoError(sb, err)
			require.InEpsilon(sb, float32(test.value), v, epsilon)
		})
	}

	for _, test := range floats {
		b.Run("benchmark_genericParseTypeFloat64s", func(sb *testing.B) {
			sb.ReportAllocs()
			sb.ResetTimer()
			sb.RunParallel(func(pb *testing.PB) {
				// Goroutine-local results avoid a data race between workers.
				var sink float64
				var sinkErr error
				for pb.Next() {
					sink, sinkErr = genericParseType[float64](test.str)
				}
				_, _ = sink, sinkErr
			})

			v, err := genericParseType[float64](test.str)
			require.NoError(sb, err)
			require.InEpsilon(sb, test.value, v, epsilon)
		})
	}
}
// go test -v -run=^$ -bench=Benchmark_GenericParseTypeBytes -benchmem -count=4
//
// Benchmark_GenericParseTypeBytes measures genericParseType[[]byte] and then
// validates one parse per case. Empty input must yield nil together with
// errParsedEmptyBytes.
func Benchmark_GenericParseTypeBytes(b *testing.B) {
	b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)")
	cases := []struct {
		str   string
		err   error
		value []byte
	}{
		{str: "alex", value: []byte("alex")},
		{str: "32.23", value: []byte("32.23")},
		{str: "john", value: []byte("john")},
		{str: "", value: []byte(nil), err: errParsedEmptyBytes},
	}

	for _, test := range cases {
		b.Run("benchmark_genericParseTypeBytes", func(b *testing.B) {
			b.ReportAllocs()
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				// Results stay goroutine-local. Assigning to variables shared
				// by all RunParallel workers is a data race.
				var sink []byte
				var sinkErr error
				for pb.Next() {
					sink, sinkErr = genericParseType[[]byte](test.str)
				}
				_, _ = sink, sinkErr
			})

			v, err := genericParseType[[]byte](test.str)
			if test.err == nil {
				require.NoError(b, err)
			} else {
				require.ErrorIs(b, err, test.err)
			}
			require.Equal(b, test.value, v)
		})
	}
}
// go test -v -run=^$ -bench=Benchmark_GenericParseTypeString -benchmem -count=4
//
// Benchmark_GenericParseTypeString measures genericParseType[string] and then
// validates that one parse returns its input unchanged.
func Benchmark_GenericParseTypeString(b *testing.B) {
	b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)")
	tests := []string{"john", "doe", "hello", "fiber"}
	for _, test := range tests {
		b.Run("benchmark_genericParseTypeString", func(b *testing.B) {
			b.ReportAllocs()
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				// Results stay goroutine-local. Assigning to variables shared
				// by all RunParallel workers is a data race.
				var sink string
				var sinkErr error
				for pb.Next() {
					sink, sinkErr = genericParseType[string](test)
				}
				_, _ = sink, sinkErr
			})

			v, err := genericParseType[string](test)
			require.NoError(b, err)
			require.Equal(b, test, v)
		})
	}
}
// go test -v -run=^$ -bench=Benchmark_GenericParseTypeBoolean -benchmem -count=4
//
// Benchmark_GenericParseTypeBoolean measures genericParseType[bool] on lower-
// and title-case literals, then validates one parse per case.
func Benchmark_GenericParseTypeBoolean(b *testing.B) {
	b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)")
	bools := []struct {
		str   string
		value bool
	}{
		{str: "True", value: true},
		{str: "False", value: false},
		{str: "true", value: true},
		{str: "false", value: false},
	}

	for _, test := range bools {
		b.Run("benchmark_genericParseTypeBoolean", func(b *testing.B) {
			b.ReportAllocs()
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				// Results stay goroutine-local. Assigning to variables shared
				// by all RunParallel workers is a data race.
				var sink bool
				var sinkErr error
				for pb.Next() {
					sink, sinkErr = genericParseType[bool](test.str)
				}
				_, _ = sink, sinkErr
			})

			v, err := genericParseType[bool](test.str)
			require.NoError(b, err)
			if test.value {
				require.True(b, v)
			} else {
				require.False(b, v)
			}
		})
	}
}
// Test_UnescapeHeaderValue checks that backslash escapes are removed from
// header values. A trailing lone backslash must produce an error.
func Test_UnescapeHeaderValue(t *testing.T) {
	t.Parallel()
	cases := []struct {
		in  string
		out []byte
		ok  bool
	}{
		{in: "abc", out: []byte("abc"), ok: true},
		{in: "a\\\"b", out: []byte("a\"b"), ok: true},
		{in: "c\\\\d", out: []byte("c\\d"), ok: true},
		{in: "bad\\", ok: false},
	}
	for _, tc := range cases {
		got, err := unescapeHeaderValue([]byte(tc.in))
		if !tc.ok {
			require.Error(t, err, tc.in)
			continue
		}
		require.NoError(t, err, tc.in)
		require.Equal(t, tc.out, got, tc.in)
	}
}
// Test_JoinHeaderValues checks that joinHeaderValues returns nil for no
// input and joins multiple values with a comma.
func Test_JoinHeaderValues(t *testing.T) {
	t.Parallel()
	require.Nil(t, joinHeaderValues(nil))

	cases := []struct {
		parts [][]byte
		want  []byte
	}{
		{parts: [][]byte{[]byte("a")}, want: []byte("a")},
		{parts: [][]byte{[]byte("a"), []byte("b")}, want: []byte("a,b")},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, joinHeaderValues(tc.parts))
	}
}
func Test_ParamsMatch_InvalidEscape(t *testing.T) {
t.Parallel()
match := paramsMatch(headerParams{"foo": []byte("bar")}, `;foo="bar\\`)
require.False(t, match)
}
// Test_MatchEtag checks weak ETag comparison. A W/ prefix on either side is
// ignored, and unquoted tags never match.
func Test_MatchEtag(t *testing.T) {
	t.Parallel()
	cases := []struct {
		left, right string
		want        bool
	}{
		{left: `"a"`, right: `"a"`, want: true},
		{left: `W/"a"`, right: `"a"`, want: true},
		{left: `"a"`, right: `W/"a"`, want: true},
		{left: `"a"`, right: `"b"`, want: false},
		{left: `a`, right: `"a"`, want: false},
		{left: `"a"`, right: `b`, want: false},
	}
	for _, c := range cases {
		require.Equal(t, c.want, matchEtag(c.left, c.right))
	}
}
// Test_MatchEtagStrong checks strong ETag comparison. Any weak (W/) tag
// prevents a match, as do differing or unquoted tags.
func Test_MatchEtagStrong(t *testing.T) {
	t.Parallel()
	cases := []struct {
		left, right string
		want        bool
	}{
		{left: `"a"`, right: `"a"`, want: true},
		{left: `W/"a"`, right: `"a"`, want: false},
		{left: `"a"`, right: `W/"a"`, want: false},
		{left: `"a"`, right: `"b"`, want: false},
		{left: `a`, right: `"a"`, want: false},
		{left: `"a"`, right: `b`, want: false},
	}
	for _, c := range cases {
		require.Equal(t, c.want, matchEtagStrong(c.left, c.right))
	}
}
// Test_IsEtagStale checks App.isEtagStale against single tags, tag lists,
// wildcards, blank headers, and weak/strong combinations.
func Test_IsEtagStale(t *testing.T) {
	t.Parallel()
	app := New()
	cases := []struct {
		etag   string
		header string
		stale  bool
	}{
		// Invalid/unquoted tags are considered a mismatch, so it's stale
		{etag: `"a"`, header: "b", stale: true},
		{etag: `"a"`, header: "a", stale: true},
		// Matching tags, not stale
		{etag: `"a"`, header: `"a"`, stale: false},
		{etag: `W/"a"`, header: `"a"`, stale: false},
		// List of tags, not stale
		{etag: `"c"`, header: `"a", "b", "c"`, stale: false},
		{etag: `W/"c"`, header: `"a", "b", "c"`, stale: false},
		{etag: `"c"`, header: `"a", "b", W/"c"`, stale: false},
		{etag: `"c"`, header: `"c", "b", "a"`, stale: false},
		{etag: `"c"`, header: ` "a", "c" , "b" `, stale: false},
		// List of tags, stale
		{etag: `"d"`, header: `"a", "b", "c"`, stale: true},
		{etag: `W/"d"`, header: `"a", "b", "c"`, stale: true},
		// Wildcard
		{etag: `"a"`, header: "*", stale: false},
		{etag: `"a"`, header: " * ", stale: false},
		{etag: `W/"a"`, header: "*", stale: false},
		// Empty case
		{etag: `"a"`, header: "", stale: true},
		{etag: `"a"`, header: " ", stale: true},
		// Weak vs. weak
		{etag: `W/"a"`, header: `W/"a"`, stale: false},
	}
	for _, tc := range cases {
		require.Equal(t, tc.stale, app.isEtagStale(tc.etag, []byte(tc.header)),
			"isEtagStale(%q, %q)", tc.etag, tc.header)
	}
}
func Test_App_quoteRawString(t *testing.T) {
t.Parallel()
cases := []struct {
name string
in string
out string
}{
{"empty", "", ""},
{"simple", "simple", "simple"},
{"backslash", "A\\B", "A\\\\B"},
{"quote", `He said "Yo"`, `He said \"Yo\"`},
{"newline", "Hello\n", "Hello\\n"},
{"carriage", "Hello\r", "Hello\\r"},
{"controls", string([]byte{0, 31, 127}), "%00%1F%7F"},
{"mixed", "test \"A\n\r" + string([]byte{1}) + "\\", `test \"A\n\r%01\\`},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
app := New()
require.Equal(t, tc.out, app.quoteRawString(tc.in))
})
}
}
func TestStoreInContext(t *testing.T) {
	t.Parallel()
	// With PassLocalsToContext enabled, StoreInContext writes the value both
	// to the request locals and to the request's context.Context.
	app := New(Config{PassLocalsToContext: true})
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)

	StoreInContext(ctx, "key", "value")

	fromLocals, isString := ctx.Locals("key").(string)
	require.True(t, isString)
	require.Equal(t, "value", fromLocals)

	fromContext, isString := ctx.Context().Value("key").(string)
	require.True(t, isString)
	require.Equal(t, "value", fromContext)
}
func TestValueFromContext(t *testing.T) {
	t.Parallel()
	// ValueFromContext accepts several context flavours; each subtest stores
	// "value" in one of them and reads it back. Unsupported inputs yield the
	// zero value and false.
	t.Run("fiber.Ctx", func(t *testing.T) {
		t.Parallel()
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(c)
		c.Locals("key", "value")
		got, found := ValueFromContext[string](c, "key")
		require.True(t, found)
		require.Equal(t, "value", got)
	})
	t.Run("fiber.CustomCtx", func(t *testing.T) {
		t.Parallel()
		app := NewWithCustomCtx(func(a *App) CustomCtx {
			return &customCtx{DefaultCtx: *NewDefaultCtx(a)}
		})
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(c)
		c.Locals("key", "value")
		got, found := ValueFromContext[string](c, "key")
		require.True(t, found)
		require.Equal(t, "value", got)
	})
	t.Run("fasthttp request ctx", func(t *testing.T) {
		t.Parallel()
		reqCtx := &fasthttp.RequestCtx{}
		reqCtx.SetUserValue("key", "value")
		got, found := ValueFromContext[string](reqCtx, "key")
		require.True(t, found)
		require.Equal(t, "value", got)
	})
	t.Run("context.Context", func(t *testing.T) {
		t.Parallel()
		type testContextKey struct{}
		stdCtx := context.WithValue(context.Background(), testContextKey{}, "value")
		got, found := ValueFromContext[string](stdCtx, testContextKey{})
		require.True(t, found)
		require.Equal(t, "value", got)
	})
	t.Run("unsupported ctx", func(t *testing.T) {
		t.Parallel()
		got, found := ValueFromContext[string](42, "key")
		require.False(t, found)
		require.Empty(t, got)
	})
}
================================================
FILE: hooks.go
================================================
package fiber
import (
"slices"
"github.com/gofiber/fiber/v3/log"
)
// Hook handler signatures. These are type aliases (declared with `=`), so any
// plain func literal with a matching signature can be registered directly.
type (
	// OnRouteHandler defines the hook signature invoked whenever a route is registered.
	// The Route is passed by value; hooks receive a copy.
	OnRouteHandler = func(Route) error
	// OnNameHandler shares the OnRouteHandler signature for route naming callbacks.
	OnNameHandler = OnRouteHandler
	// OnGroupHandler defines the hook signature invoked whenever a group is registered.
	OnGroupHandler = func(Group) error
	// OnGroupNameHandler shares the OnGroupHandler signature for group naming callbacks.
	OnGroupNameHandler = OnGroupHandler
	// OnListenHandler runs when the application begins listening and receives the listener details.
	OnListenHandler = func(ListenData) error
	// OnPreStartupMessageHandler runs before Fiber prints the startup banner.
	// The pointer lets hooks edit entries, the banner header, or set PreventDefault.
	OnPreStartupMessageHandler = func(*PreStartupMessageData) error
	// OnPostStartupMessageHandler runs after Fiber prints (or skips) the startup banner.
	OnPostStartupMessageHandler = func(*PostStartupMessageData) error
	// OnPreShutdownHandler runs before the application shuts down.
	// A returned error is logged, not propagated.
	OnPreShutdownHandler = func() error
	// OnPostShutdownHandler runs after shutdown and receives the shutdown result.
	// A returned error is logged, not propagated.
	OnPostShutdownHandler = func(error) error
	// OnForkHandler runs inside a forked worker process and receives the worker ID.
	// A returned error is logged, not propagated.
	OnForkHandler = func(int) error
	// OnMountHandler runs after a sub-application mounts to a parent and receives the parent app reference.
	OnMountHandler = func(*App) error
)
// Hooks holds the lifecycle callbacks registered on an App. Handlers are added
// through the On* methods, which take the app mutex while appending. The
// execute* helpers run each handler slice in registration order.
type Hooks struct {
	// app is the owning application (a plain field, not an embedded one). It
	// supplies the registration mutex and the mount prefix applied to
	// route/group hooks.
	app *App
	// Hooks, one slice per lifecycle event, in registration order.
	onRoute        []OnRouteHandler
	onName         []OnNameHandler
	onGroup        []OnGroupHandler
	onGroupName    []OnGroupNameHandler
	onListen       []OnListenHandler
	onPreStartup   []OnPreStartupMessageHandler
	onPostStartup  []OnPostStartupMessageHandler
	onPreShutdown  []OnPreShutdownHandler
	onPostShutdown []OnPostShutdownHandler
	onFork         []OnForkHandler
	onMount        []OnMountHandler
}
// StartupMessageLevel classifies a startup message entry as informational,
// warning, or error. Entries are tagged with a level by
// PreStartupMessageData.AddInfo, AddWarning, and AddError respectively.
type StartupMessageLevel int

const (
	// StartupMessageLevelInfo represents informational startup message entries.
	StartupMessageLevelInfo StartupMessageLevel = iota
	// StartupMessageLevelWarning represents warning startup message entries.
	StartupMessageLevelWarning
	// StartupMessageLevelError represents error startup message entries.
	StartupMessageLevelError
)
// errString is the literal "ERROR" label text. NOTE(review): it is not
// referenced in this file; it is presumably used when rendering
// StartupMessageLevelError entries — confirm at the use site.
const errString = "ERROR"
// startupMessageEntry represents a single line of startup message information.
// Entries are identified by key: adding an entry with an existing key
// overwrites that entry in place instead of appending a duplicate.
type startupMessageEntry struct {
	key   string
	title string
	value string
	// priority is -1 when the caller supplied none. NOTE(review): how priority
	// affects display order is decided by the renderer, which is not in this file.
	priority int
	level    StartupMessageLevel
}
// ListenData contains the listener metadata provided to OnListenHandler.
// PreStartupMessageData and PostStartupMessageData embed a *ListenData, so the
// same fields are available to the startup-message hooks.
type ListenData struct {
	// ColorScheme is the app's configured color scheme.
	ColorScheme Colors
	// Host is the host part of the listen address; an empty host (e.g. ":3030")
	// is reported as globalIpv4Addr.
	Host string
	// Port is the port part of the listen address.
	Port string
	// Version is the Fiber framework version.
	Version string
	// AppName is the configured application name.
	AppName string
	// ChildPIDs lists the prefork child process IDs, if any.
	ChildPIDs []int
	// HandlerCount is the number of handlers registered on the app.
	HandlerCount int
	// ProcessCount is runtime.GOMAXPROCS(0) at the time the data was built.
	ProcessCount int
	// PID is the current process ID.
	PID int
	// TLS reports whether the listener serves TLS.
	TLS bool
	// Prefork reports whether prefork mode is enabled.
	Prefork bool
}
// PreStartupMessageData contains metadata exposed to OnPreStartupMessage hooks.
// Hooks can customize the banner header, edit the message entries, or suppress
// the default message entirely. All pre-startup hooks share one instance, so
// changes made by earlier hooks are visible to later ones.
type PreStartupMessageData struct {
	// NOTE: newPreStartupMessageData stores the caller's pointer as-is, so
	// writes to these fields are visible to whoever owns that ListenData.
	*ListenData
	// BannerHeader allows overriding the ASCII art banner displayed at startup.
	BannerHeader string
	// entries holds the message lines; manage them via AddInfo, AddWarning,
	// AddError, DeleteEntry, ResetEntries and inspect keys via EntryKeys.
	entries []startupMessageEntry
	// PreventDefault, when set to true, suppresses the default startup message.
	PreventDefault bool
}
// AddInfo adds an informational entry to the startup message, or replaces the
// existing entry with the same key. Only the first optional priority value is
// used; it defaults to -1 when omitted.
func (sm *PreStartupMessageData) AddInfo(key, title, value string, priority ...int) {
	if len(priority) == 0 {
		sm.addEntry(key, title, value, -1, StartupMessageLevelInfo)
		return
	}
	sm.addEntry(key, title, value, priority[0], StartupMessageLevelInfo)
}
// AddWarning adds a warning entry to the startup message, or replaces the
// existing entry with the same key. Only the first optional priority value is
// used; it defaults to -1 when omitted.
func (sm *PreStartupMessageData) AddWarning(key, title, value string, priority ...int) {
	if len(priority) == 0 {
		sm.addEntry(key, title, value, -1, StartupMessageLevelWarning)
		return
	}
	sm.addEntry(key, title, value, priority[0], StartupMessageLevelWarning)
}
// AddError adds an error entry to the startup message, or replaces the
// existing entry with the same key. Only the first optional priority value is
// used; it defaults to -1 when omitted.
func (sm *PreStartupMessageData) AddError(key, title, value string, priority ...int) {
	if len(priority) == 0 {
		sm.addEntry(key, title, value, -1, StartupMessageLevelError)
		return
	}
	sm.addEntry(key, title, value, priority[0], StartupMessageLevelError)
}
// EntryKeys returns the keys of all current startup message entries in their
// current order. The result is a fresh, non-nil slice the caller may modify.
func (sm *PreStartupMessageData) EntryKeys() []string {
	keys := make([]string, len(sm.entries))
	for i := range sm.entries {
		keys[i] = sm.entries[i].key
	}
	return keys
}
// ResetEntries removes all entries from the startup message. The backing
// storage is kept (the length is truncated to zero) so later additions can
// reuse it.
func (sm *PreStartupMessageData) ResetEntries() {
	if len(sm.entries) == 0 {
		return
	}
	sm.entries = sm.entries[:0]
}
// DeleteEntry removes the entry identified by key from the startup message.
// It is a no-op when no entry with that key exists, including when no entries
// have been added yet. Keys are unique (addEntry overwrites duplicates), so at
// most one entry is removed.
func (sm *PreStartupMessageData) DeleteEntry(key string) {
	idx := slices.IndexFunc(sm.entries, func(entry startupMessageEntry) bool {
		return entry.key == key
	})
	if idx < 0 {
		return
	}
	// slices.Delete also zeroes the vacated tail slot, so the removed entry's
	// strings are not kept reachable through the backing array.
	sm.entries = slices.Delete(sm.entries, idx, idx+1)
}
// addEntry inserts a startup message entry, or overwrites the entry that
// already uses key while keeping its position in the list.
func (sm *PreStartupMessageData) addEntry(key, title, value string, priority int, level StartupMessageLevel) {
	if sm.entries == nil {
		sm.entries = make([]startupMessageEntry, 0, 8)
	}
	updated := startupMessageEntry{
		key:      key,
		title:    title,
		value:    value,
		priority: priority,
		level:    level,
	}
	idx := slices.IndexFunc(sm.entries, func(e startupMessageEntry) bool {
		return e.key == key
	})
	if idx >= 0 {
		sm.entries[idx] = updated
		return
	}
	sm.entries = append(sm.entries, updated)
}
// newPreStartupMessageData wraps listenData (shared, not copied) in an empty
// pre-startup message payload.
func newPreStartupMessageData(listenData *ListenData) *PreStartupMessageData {
	data := new(PreStartupMessageData)
	data.ListenData = listenData
	return data
}
// PostStartupMessageData contains metadata exposed to OnPostStartupMessage hooks.
type PostStartupMessageData struct {
	// ListenData is the listener metadata. When built by
	// newPostStartupMessageData it is a copy (with ChildPIDs cloned), so hooks
	// cannot alter the caller's ListenData through it.
	*ListenData
	// Disabled indicates whether the startup message was disabled via configuration.
	Disabled bool
	// IsChild indicates whether the current process is a child in prefork mode.
	IsChild bool
	// Prevented indicates whether the startup message was suppressed by a pre-startup hook using PreventDefault property.
	Prevented bool
}
// newPostStartupMessageData builds the post-startup hook payload from a
// snapshot of listenData; ChildPIDs is cloned so hooks cannot mutate the
// caller's slice.
func newPostStartupMessageData(listenData *ListenData, disabled, isChild, prevented bool) *PostStartupMessageData {
	snapshot := *listenData
	if len(snapshot.ChildPIDs) != 0 {
		snapshot.ChildPIDs = slices.Clone(snapshot.ChildPIDs)
	}
	data := &PostStartupMessageData{ListenData: &snapshot}
	data.Disabled = disabled
	data.IsChild = isChild
	data.Prevented = prevented
	return data
}
// newHooks returns a Hooks bound to app with every handler list initialized
// to an empty, non-nil slice.
func newHooks(app *App) *Hooks {
	h := new(Hooks)
	h.app = app
	h.onRoute = []OnRouteHandler{}
	h.onName = []OnNameHandler{}
	h.onGroup = []OnGroupHandler{}
	h.onGroupName = []OnGroupNameHandler{}
	h.onListen = []OnListenHandler{}
	h.onPreStartup = []OnPreStartupMessageHandler{}
	h.onPostStartup = []OnPostStartupMessageHandler{}
	h.onPreShutdown = []OnPreShutdownHandler{}
	h.onPostShutdown = []OnPostShutdownHandler{}
	h.onFork = []OnForkHandler{}
	h.onMount = []OnMountHandler{}
	return h
}
// OnRoute registers handlers that run each time a route is registered. The
// handler receives a copy of the route, with the mount prefix applied to its
// path when the app is mounted.
func (h *Hooks) OnRoute(handler ...OnRouteHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onRoute = append(h.onRoute, handler...)
}
// OnName registers handlers that run each time a route is named. The handler
// receives a copy of the route.
//
// WARN: OnName only works with naming routes, not groups.
func (h *Hooks) OnName(handler ...OnNameHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onName = append(h.onName, handler...)
}
// OnGroup registers handlers that run each time a group is registered. The
// handler receives the group by value, with the mount prefix applied when the
// app is mounted.
func (h *Hooks) OnGroup(handler ...OnGroupHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onGroup = append(h.onGroup, handler...)
}
// OnGroupName registers handlers that run each time a group is named. The
// handler receives the group by value.
//
// WARN: OnGroupName only works with naming groups, not routes.
func (h *Hooks) OnGroupName(handler ...OnGroupNameHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onGroupName = append(h.onGroupName, handler...)
}
// OnListen registers handlers that run when the app starts listening (via
// Listen or Listener). Each handler receives a copy of the ListenData.
func (h *Hooks) OnListen(handler ...OnListenHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onListen = append(h.onListen, handler...)
}
// OnPreStartupMessage registers handlers that run before the startup message
// is printed; they may edit or suppress the message via the shared payload.
func (h *Hooks) OnPreStartupMessage(handler ...OnPreStartupMessageHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onPreStartup = append(h.onPreStartup, handler...)
}
// OnPostStartupMessage registers handlers that run after the startup message
// has been printed or skipped.
func (h *Hooks) OnPostStartupMessage(handler ...OnPostStartupMessageHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onPostStartup = append(h.onPostStartup, handler...)
}
// OnPreShutdown registers handlers that run before Shutdown. Errors they
// return are logged, not propagated.
func (h *Hooks) OnPreShutdown(handler ...OnPreShutdownHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onPreShutdown = append(h.onPreShutdown, handler...)
}
// OnPostShutdown registers handlers that run after Shutdown and receive its
// result. Errors they return are logged, not propagated.
func (h *Hooks) OnPostShutdown(handler ...OnPostShutdownHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onPostShutdown = append(h.onPostShutdown, handler...)
}
// OnFork registers handlers that run after a worker process is forked and
// receive its ID. Errors they return are logged, not propagated.
func (h *Hooks) OnFork(handler ...OnForkHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onFork = append(h.onFork, handler...)
}
// OnMount registers handlers that run after this app is mounted on a parent.
// The parent app is passed as the parameter. It works for app and group
// mounting.
func (h *Hooks) OnMount(handler ...OnMountHandler) {
	h.app.mutex.Lock()
	defer h.app.mutex.Unlock()
	h.onMount = append(h.onMount, handler...)
}
// executeOnRouteHooks runs the OnRoute handlers in order against a copy of
// route, prefixing the mount path when the app is mounted. A nil route is
// ignored. The first handler error is returned and stops the chain.
func (h *Hooks) executeOnRouteHooks(route *Route) error {
	if route == nil {
		return nil
	}
	view := *route
	if prefix := h.app.mountFields.mountPath; prefix != "" {
		view.path = prefix + view.path
		view.Path = view.path
	}
	for i := range h.onRoute {
		if err := h.onRoute[i](view); err != nil {
			return err
		}
	}
	return nil
}
// executeOnNameHooks runs the OnName handlers in order against a copy of
// route, prefixing the mount path when the app is mounted. A nil route is
// ignored. The first handler error is returned and stops the chain.
func (h *Hooks) executeOnNameHooks(route *Route) error {
	if route == nil {
		return nil
	}
	view := *route
	if prefix := h.app.mountFields.mountPath; prefix != "" {
		view.path = prefix + view.path
		view.Path = view.path
	}
	for i := range h.onName {
		if err := h.onName[i](view); err != nil {
			return err
		}
	}
	return nil
}
// executeOnGroupHooks runs the OnGroup handlers in order. group is received by
// value, so prefixing the mount path only changes what the handlers observe.
// The first handler error is returned and stops the chain.
func (h *Hooks) executeOnGroupHooks(group Group) error {
	if prefix := h.app.mountFields.mountPath; prefix != "" {
		group.Prefix = prefix + group.Prefix
	}
	for i := range h.onGroup {
		if err := h.onGroup[i](group); err != nil {
			return err
		}
	}
	return nil
}
// executeOnGroupNameHooks runs the OnGroupName handlers in order. group is
// received by value, so prefixing the mount path only changes what the
// handlers observe. The first handler error is returned and stops the chain.
func (h *Hooks) executeOnGroupNameHooks(group Group) error {
	if prefix := h.app.mountFields.mountPath; prefix != "" {
		group.Prefix = prefix + group.Prefix
	}
	for i := range h.onGroupName {
		if err := h.onGroupName[i](group); err != nil {
			return err
		}
	}
	return nil
}
// executeOnListenHooks runs the OnListen handlers in order, giving each a copy
// of *listenData. The first handler error is returned and stops the chain.
func (h *Hooks) executeOnListenHooks(listenData *ListenData) error {
	for i := range h.onListen {
		if err := h.onListen[i](*listenData); err != nil {
			return err
		}
	}
	return nil
}
// executeOnPreStartupMessageHooks runs the pre-startup handlers in order. All
// handlers share data, so edits by earlier handlers are visible to later ones.
// The first handler error is returned and stops the chain.
func (h *Hooks) executeOnPreStartupMessageHooks(data *PreStartupMessageData) error {
	for i := range h.onPreStartup {
		if err := h.onPreStartup[i](data); err != nil {
			return err
		}
	}
	return nil
}
// executeOnPostStartupMessageHooks runs the post-startup handlers in order
// with the shared data. The first handler error is returned and stops the
// chain.
func (h *Hooks) executeOnPostStartupMessageHooks(data *PostStartupMessageData) error {
	for i := range h.onPostStartup {
		if err := h.onPostStartup[i](data); err != nil {
			return err
		}
	}
	return nil
}
// executeOnPreShutdownHooks runs every pre-shutdown handler in order. A
// failing handler is logged and does not stop the remaining handlers.
func (h *Hooks) executeOnPreShutdownHooks() {
	for i := range h.onPreShutdown {
		err := h.onPreShutdown[i]()
		if err == nil {
			continue
		}
		log.Errorf("failed to call pre shutdown hook: %v", err)
	}
}
// executeOnPostShutdownHooks runs every post-shutdown handler in order, passing
// the shutdown result err. A failing handler is logged and does not stop the
// remaining handlers.
func (h *Hooks) executeOnPostShutdownHooks(err error) {
	for i := range h.onPostShutdown {
		hookErr := h.onPostShutdown[i](err)
		if hookErr == nil {
			continue
		}
		log.Errorf("failed to call post shutdown hook: %v", hookErr)
	}
}
// executeOnForkHooks runs every fork handler in order with pid. A failing
// handler is logged and does not stop the remaining handlers.
func (h *Hooks) executeOnForkHooks(pid int) {
	for i := range h.onFork {
		err := h.onFork[i](pid)
		if err == nil {
			continue
		}
		log.Errorf("failed to call fork hook: %v", err)
	}
}
// executeOnMountHooks runs the OnMount handlers in order with the parent app.
// The first handler error is returned and stops the chain.
func (h *Hooks) executeOnMountHooks(app *App) error {
	for i := range h.onMount {
		if err := h.onMount[i](app); err != nil {
			return err
		}
	}
	return nil
}
================================================
FILE: hooks_test.go
================================================
package fiber
import (
"bytes"
"errors"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/bytebufferpool"
"github.com/gofiber/fiber/v3/log"
)
// testMountPath is the mount prefix injected directly into app.mountFields by
// the execute*Hooks tests to simulate a mounted sub-app.
const testMountPath = "/api"
// testSimpleHandler is a trivial route handler that responds with "simple".
func testSimpleHandler(c Ctx) error {
	const body = "simple"
	return c.SendString(body)
}
func Test_Hook_OnRoute(t *testing.T) {
	t.Parallel()
	parent := New()
	parent.Hooks().OnRoute(func(route Route) error {
		// The hook fires at registration, before .Name("x") is applied.
		require.Empty(t, route.Name)
		return nil
	})
	parent.Get("/", testSimpleHandler).Name("x")

	child := New()
	child.Get("/test", testSimpleHandler)
	parent.Use("/sub", child)
}
func Test_Hook_OnRoute_Mount(t *testing.T) {
	t.Parallel()
	root := New()
	child := New()
	// Mount before registering routes so the child's hooks see the prefix.
	root.Use("/sub", child)

	child.Hooks().OnRoute(func(route Route) error {
		require.Equal(t, "/sub/test", route.Path)
		return nil
	})
	root.Hooks().OnRoute(func(route Route) error {
		require.Equal(t, "/", route.Path)
		return nil
	})

	root.Get("/", testSimpleHandler).Name("x")
	child.Get("/test", testSimpleHandler)
}
func Test_Hook_OnName(t *testing.T) {
	t.Parallel()
	app := New()
	names := bytebufferpool.Get()
	defer bytebufferpool.Put(names)

	app.Hooks().OnName(func(route Route) error {
		_, err := names.WriteString(route.Name)
		require.NoError(t, err)
		return nil
	})

	app.Get("/", testSimpleHandler).Name("index")

	sub := New()
	sub.Get("/test", testSimpleHandler)
	sub.Get("/test2", testSimpleHandler)
	app.Use("/sub", sub)

	// Only the explicitly named route reached the OnName hook.
	require.Equal(t, "index", names.String())
}
func Test_Hook_OnName_Error(t *testing.T) {
	t.Parallel()
	app := New()
	failing := func(_ Route) error {
		return errors.New("unknown error")
	}
	app.Hooks().OnName(failing)

	// An OnName hook error surfaces as a panic from Name().
	require.PanicsWithError(t, "unknown error", func() {
		app.Get("/", testSimpleHandler).Name("index")
	})
}
func Test_Hook_OnGroup(t *testing.T) {
	t.Parallel()
	app := New()
	prefixes := bytebufferpool.Get()
	defer bytebufferpool.Put(prefixes)

	app.Hooks().OnGroup(func(group Group) error {
		_, err := prefixes.WriteString(group.Prefix)
		require.NoError(t, err)
		return nil
	})

	outer := app.Group("/x").Name("x.")
	outer.Group("/a")

	// Nested groups report their full prefix: "/x" followed by "/x/a".
	require.Equal(t, "/x/x/a", prefixes.String())
}
func Test_Hook_OnGroup_Mount(t *testing.T) {
	t.Parallel()
	app := New()
	parent := New()
	parent.Use("/john", app)

	app.Hooks().OnGroup(func(group Group) error {
		// The mount prefix is prepended to the group's own prefix.
		require.Equal(t, "/john/v1", group.Prefix)
		return nil
	})

	app.Group("/v1").Get("/doe", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})
}
func Test_Hook_OnGroupName(t *testing.T) {
	t.Parallel()
	app := New()
	groupNames := bytebufferpool.Get()
	defer bytebufferpool.Put(groupNames)
	routeNames := bytebufferpool.Get()
	defer bytebufferpool.Put(routeNames)

	hooks := app.Hooks()
	hooks.OnGroupName(func(group Group) error {
		_, err := groupNames.WriteString(group.name)
		require.NoError(t, err)
		return nil
	})
	hooks.OnName(func(route Route) error {
		_, err := routeNames.WriteString(route.Name)
		require.NoError(t, err)
		return nil
	})

	grp := app.Group("/x").Name("x.")
	grp.Get("/test", testSimpleHandler).Name("test")
	grp.Get("/test2", testSimpleHandler)

	require.Equal(t, "x.", groupNames.String())
	// Route names are prefixed with the enclosing group's name.
	require.Equal(t, "x.test", routeNames.String())
}
func Test_Hook_OnGroupName_Error(t *testing.T) {
	t.Parallel()
	app := New()
	failing := func(_ Group) error {
		return errors.New("unknown error")
	}
	app.Hooks().OnGroupName(failing)

	// An OnGroupName hook error surfaces as a panic from Group.Name().
	require.PanicsWithError(t, "unknown error", func() {
		_ = app.Group("/x").Name("x.")
	})
}
func Test_Hook_OnPreShutdown(t *testing.T) {
	t.Parallel()
	app := New()
	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app.Hooks().OnPreShutdown(func() error {
		_, err := out.WriteString("pre-shutdown")
		require.NoError(t, err)
		return nil
	})

	// The pre-shutdown hook has run by the time Shutdown returns.
	require.NoError(t, app.Shutdown())
	require.Equal(t, "pre-shutdown", out.String())
}
// Test_Hook_OnPostShutdown covers post-shutdown hooks: the shutdown error is
// forwarded to the hook, multiple hooks run in registration order, and a
// failing hook does not panic. Neither this test nor its subtests call
// t.Parallel.
func Test_Hook_OnPostShutdown(t *testing.T) {
	t.Run("should execute post shutdown hook with error", func(t *testing.T) {
		app := New()
		expectedErr := errors.New("test shutdown error")
		// Capacity 1 lets one hook invocation send without a waiting receiver.
		hookCalled := make(chan error, 1)
		defer close(hookCalled)
		app.Hooks().OnPostShutdown(func(err error) error {
			hookCalled <- err
			return nil
		})
		// Start a real server so the final Shutdown has something to stop.
		// Listen errors are deliberately ignored; the goroutine just exits.
		go func() {
			if err := app.Listen(":0"); err != nil {
				return
			}
		}()
		// Best-effort wait for Listen to start; this makes the test timing-dependent.
		time.Sleep(100 * time.Millisecond)
		// Invoke the hooks directly with a known error and check it is forwarded as-is.
		app.hooks.executeOnPostShutdownHooks(expectedErr)
		select {
		case err := <-hookCalled:
			require.Equal(t, expectedErr, err)
		case <-time.After(time.Second):
			t.Fatal("hook execution timeout")
		}
		// NOTE(review): if Shutdown also runs the post-shutdown hooks, that call
		// sends on hookCalled again and only avoids blocking because the receive
		// above freed the single buffer slot.
		require.NoError(t, app.Shutdown())
	})
	t.Run("should execute multiple hooks in order", func(t *testing.T) {
		app := New()
		execution := make([]int, 0)
		app.Hooks().OnPostShutdown(func(_ error) error {
			execution = append(execution, 1)
			return nil
		})
		app.Hooks().OnPostShutdown(func(_ error) error {
			execution = append(execution, 2)
			return nil
		})
		app.hooks.executeOnPostShutdownHooks(nil)
		require.Len(t, execution, 2, "expected 2 hooks to execute")
		require.Equal(t, []int{1, 2}, execution, "hooks executed in wrong order")
	})
	t.Run("should handle hook error", func(_ *testing.T) {
		app := New()
		hookErr := errors.New("hook error")
		app.Hooks().OnPostShutdown(func(_ error) error {
			return hookErr
		})
		// Should not panic: post-shutdown hook errors are only logged.
		app.hooks.executeOnPostShutdownHooks(nil)
	})
}
// Test_Hook_OnListen verifies that OnListen hooks run when the app starts
// listening. Shutdown is triggered from a separate goroutine after a fixed
// one-second delay, so the test is timing-based.
func Test_Hook_OnListen(t *testing.T) {
	t.Parallel()
	app := New()
	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)
	app.Hooks().OnListen(func(_ ListenData) error {
		_, err := buf.WriteString("ready")
		require.NoError(t, err)
		return nil
	})
	go func() {
		time.Sleep(1000 * time.Millisecond)
		// assert (not require): FailNow must only be called from the goroutine
		// running the test, so require is unsafe here.
		assert.NoError(t, app.Shutdown())
	}()
	// Listen returns once the goroutine above shuts the app down.
	require.NoError(t, app.Listen(":0"))
	require.Equal(t, "ready", buf.String())
}
func Test_ListenDataMetadata(t *testing.T) {
	t.Parallel()
	app := New(Config{AppName: "meta"})
	app.handlersCount = 42
	cfg := ListenConfig{EnablePrefork: true}
	childPIDs := []int{11, 22}
	listenData := app.prepareListenData(":3030", true, &cfg, childPIDs)

	// expectMetadata checks every ListenData field derived from the app, the
	// listen address, the TLS flag and the prefork configuration above.
	expectMetadata := func(data *ListenData) {
		require.Equal(t, globalIpv4Addr, data.Host)
		require.Equal(t, "3030", data.Port)
		require.True(t, data.TLS)
		require.Equal(t, Version, data.Version)
		require.Equal(t, "meta", data.AppName)
		require.Equal(t, 42, data.HandlerCount)
		require.Equal(t, runtime.GOMAXPROCS(0), data.ProcessCount)
		require.Equal(t, os.Getpid(), data.PID)
		require.True(t, data.Prefork)
		require.Equal(t, childPIDs, data.ChildPIDs)
		require.Equal(t, app.config.ColorScheme, data.ColorScheme)
	}

	app.Hooks().OnListen(func(data ListenData) error {
		expectMetadata(&data)
		return nil
	})
	app.runOnListenHooks(listenData)

	app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
		expectMetadata(data.ListenData)
		data.ResetEntries()
		data.AddInfo("custom", "Custom Info", "value", 3)
		data.AddInfo("other", "Other Info", "value", 2)
		return nil
	})
	pre := newPreStartupMessageData(listenData)
	require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))

	// Entries keep insertion order and their explicit priorities.
	first, second := pre.entries[0], pre.entries[1]
	require.Equal(t, "value", first.value)
	require.Equal(t, "Custom Info", first.title)
	require.Equal(t, 3, first.priority)
	require.Equal(t, "value", second.value)
	require.Equal(t, "Other Info", second.title)
	require.Equal(t, 2, second.priority)
	require.False(t, pre.PreventDefault)
}
func Test_ListenData_Hook_HelperFunctions(t *testing.T) {
t.Parallel()
t.Run("EntryKeys", func(t *testing.T) {
t.Parallel()
app := New()
app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
data.ResetEntries()
data.AddInfo("key1", "Title 1", "Value 1", 1)
data.AddInfo("key2", "Title 2", "Value 2", 2)
keys := data.EntryKeys()
require.Len(t, keys, 2)
require.Equal(t, "key1", keys[0])
require.Equal(t, "key2", keys[1])
return nil
})
pre := newPreStartupMessageData(&ListenData{})
require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
})
t.Run("ResetEntries", func(t *testing.T) {
t.Parallel()
app := New()
app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
data.ResetEntries()
data.AddInfo("key1", "Title 1", "Value 1", 1)
data.AddInfo("key2", "Title 2", "Value 2", 2)
require.Len(t, data.entries, 2)
data.ResetEntries()
require.Empty(t, data.entries)
return nil
})
pre := newPreStartupMessageData(&ListenData{})
require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
})
t.Run("AddInfo", func(t *testing.T) {
t.Parallel()
app := New()
app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
data.ResetEntries()
data.AddInfo("key1", "Title 1", "Value 1", 1)
require.Len(t, data.entries, 1)
require.Equal(t, "key1", data.entries[0].key)
require.Equal(t, "Title 1", data.entries[0].title)
require.Equal(t, "Value 1", data.entries[0].value)
require.Equal(t, 1, data.entries[0].priority)
return nil
})
pre := newPreStartupMessageData(&ListenData{})
require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
})
t.Run("AddWarning", func(t *testing.T) {
t.Parallel()
app := New()
app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
data.ResetEntries()
data.AddWarning("key1", "Title 1", "Value 1", 1)
require.Len(t, data.entries, 1)
require.Equal(t, "key1", data.entries[0].key)
require.Equal(t, "Title 1", data.entries[0].title)
require.Equal(t, "Value 1", data.entries[0].value)
require.Equal(t, 1, data.entries[0].priority)
return nil
})
pre := newPreStartupMessageData(&ListenData{})
require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
})
t.Run("AddError", func(t *testing.T) {
t.Parallel()
app := New()
app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
data.ResetEntries()
data.AddError("key1", "Title 1", "Value 1", 1)
require.Len(t, data.entries, 1)
require.Equal(t, "key1", data.entries[0].key)
require.Equal(t, "Title 1", data.entries[0].title)
require.Equal(t, "Value 1", data.entries[0].value)
require.Equal(t, 1, data.entries[0].priority)
return nil
})
pre := newPreStartupMessageData(&ListenData{})
require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
})
t.Run("AddInfo-UpdateExisting", func(t *testing.T) {
t.Parallel()
app := New()
app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
data.ResetEntries()
data.AddInfo("key1", "Title 1", "Value 1", 1)
data.AddInfo("key1", "Updated Title", "Updated Value", 2)
require.Len(t, data.entries, 1)
require.Equal(t, "key1", data.entries[0].key)
require.Equal(t, "Updated Title", data.entries[0].title)
require.Equal(t, "Updated Value", data.entries[0].value)
require.Equal(t, 2, data.entries[0].priority)
return nil
})
pre := newPreStartupMessageData(&ListenData{})
require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
})
t.Run("DeleteEntry", func(t *testing.T) {
t.Parallel()
app := New()
app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
data.ResetEntries()
data.AddInfo("key1", "Title 1", "Value 1", 1)
data.AddInfo("key2", "Title 2", "Value 2", 2)
require.Len(t, data.entries, 2)
data.DeleteEntry("key1")
require.Len(t, data.entries, 1)
require.Equal(t, "key2", data.entries[0].key)
data.DeleteEntry("key-not-exist") // should not panic
require.Len(t, data.entries, 1)
return nil
})
pre := newPreStartupMessageData(&ListenData{})
require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
})
}
// Test_Hook_OnListenPrefork verifies that OnListen hooks also run when
// listening with prefork enabled. Like Test_Hook_OnListen, it relies on a
// goroutine calling Shutdown after a fixed one-second delay.
func Test_Hook_OnListenPrefork(t *testing.T) {
	t.Parallel()
	app := New()
	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)
	app.Hooks().OnListen(func(_ ListenData) error {
		_, err := buf.WriteString("ready")
		require.NoError(t, err)
		return nil
	})
	go func() {
		time.Sleep(1000 * time.Millisecond)
		// assert (not require): FailNow must not be called off the test goroutine.
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true}))
	require.Equal(t, "ready", buf.String())
}
// Test_Hook_OnHook exercises the OnFork hook through app.prefork. It sets the
// package-level flags testPreforkMaster and testOnPrefork, which presumably
// switch prefork into a test mode. Because it mutates package globals, it does
// not call t.Parallel. Shutdown is triggered after a fixed one-second delay.
func Test_Hook_OnHook(t *testing.T) {
	app := New()
	// Reset test var
	testPreforkMaster = true
	testOnPrefork = true
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	// NOTE(review): the expected worker ID of 1 depends on how prefork assigns
	// IDs under the test flags; confirm in prefork.go.
	app.Hooks().OnFork(func(pid int) error {
		require.Equal(t, 1, pid)
		return nil
	})
	require.NoError(t, app.prefork(":0", nil, &ListenConfig{DisableStartupMessage: true, EnablePrefork: true}))
}
func Test_Hook_OnMount(t *testing.T) {
	t.Parallel()
	parent := New()
	parent.Get("/", testSimpleHandler).Name("x")

	child := New()
	child.Get("/test", testSimpleHandler)
	child.Hooks().OnMount(func(mountedOn *App) error {
		// The parent is a root app, so it has no mount prefix of its own.
		require.Empty(t, mountedOn.mountFields.mountPath)
		return nil
	})

	parent.Use("/sub", child)
}
func Test_executeOnRouteHooks_ErrorWithMount(t *testing.T) {
	t.Parallel()
	app := New()
	app.mountFields.mountPath = testMountPath

	var seenPath string
	app.Hooks().OnRoute(func(route Route) error {
		seenPath = route.Path
		return errors.New("hook error")
	})

	err := app.hooks.executeOnRouteHooks(&Route{Path: "/foo", path: "/foo"})
	// The hook sees the mount-prefixed path and its error is returned as-is.
	wantPath := testMountPath + "/foo"
	require.Equal(t, wantPath, seenPath)
	require.EqualError(t, err, "hook error")
}
func Test_executeOnNameHooks_ErrorWithMount(t *testing.T) {
	t.Parallel()
	app := New()
	app.mountFields.mountPath = testMountPath

	var seenPath string
	app.Hooks().OnName(func(route Route) error {
		seenPath = route.Path
		return errors.New("name error")
	})

	err := app.hooks.executeOnNameHooks(&Route{Path: "/bar", path: "/bar"})
	// The hook sees the mount-prefixed path and its error is returned as-is.
	wantPath := testMountPath + "/bar"
	require.Equal(t, wantPath, seenPath)
	require.EqualError(t, err, "name error")
}
func Test_executeOnGroupHooks_ErrorWithMount(t *testing.T) {
	t.Parallel()
	app := New()
	app.mountFields.mountPath = testMountPath

	var seenPrefix string
	app.Hooks().OnGroup(func(group Group) error {
		seenPrefix = group.Prefix
		return errors.New("group error")
	})

	err := app.hooks.executeOnGroupHooks(Group{Prefix: "/grp"})
	// The hook sees the mount-prefixed group prefix and its error is returned as-is.
	wantPrefix := testMountPath + "/grp"
	require.Equal(t, wantPrefix, seenPrefix)
	require.EqualError(t, err, "group error")
}
func Test_executeOnGroupNameHooks_ErrorWithMount(t *testing.T) {
	t.Parallel()
	app := New()
	app.mountFields.mountPath = testMountPath

	var seenPrefix string
	app.Hooks().OnGroupName(func(group Group) error {
		seenPrefix = group.Prefix
		return errors.New("group name error")
	})

	err := app.hooks.executeOnGroupNameHooks(Group{Prefix: "/grp"})
	// The hook sees the mount-prefixed group prefix and its error is returned as-is.
	wantPrefix := testMountPath + "/grp"
	require.Equal(t, wantPrefix, seenPrefix)
	require.EqualError(t, err, "group name error")
}
func Test_executeOnListenHooks_Error(t *testing.T) {
	t.Parallel()
	app := New()
	failing := func(_ ListenData) error {
		return errors.New("listen error")
	}
	app.Hooks().OnListen(failing)

	data := &ListenData{Host: "127.0.0.1", Port: "0"}
	require.EqualError(t, app.hooks.executeOnListenHooks(data), "listen error")
}
func Test_executeOnPreStartupMessageHooks_Error(t *testing.T) {
	t.Parallel()
	app := New()
	failing := func(_ *PreStartupMessageData) error {
		return errors.New("pre startup message error")
	}
	app.Hooks().OnPreStartupMessage(failing)

	data := newPreStartupMessageData(&ListenData{})
	require.EqualError(t, app.hooks.executeOnPreStartupMessageHooks(data), "pre startup message error")
}
func Test_executeOnPostStartupMessageHooks_Error(t *testing.T) {
	t.Parallel()
	app := New()
	failing := func(_ *PostStartupMessageData) error {
		return errors.New("post startup message error")
	}
	app.Hooks().OnPostStartupMessage(failing)

	data := newPostStartupMessageData(&ListenData{}, false, false, false)
	require.EqualError(t, app.hooks.executeOnPostStartupMessageHooks(data), "post startup message error")
}
// Test_executeOnPreShutdownHooks_Error verifies that a failing pre-shutdown
// hook is logged rather than propagated.
//
// This test redirects the package-global logger output, so it deliberately
// does NOT call t.Parallel. A parallel sibling that also calls log.SetOutput
// (e.g. the fork-hook test) could otherwise swap the writer between our
// SetOutput and the hook's log call, making the buffer assertion flaky. The
// original output is restored when the test finishes.
func Test_executeOnPreShutdownHooks_Error(t *testing.T) {
	app := New()
	app.Hooks().OnPreShutdown(func() error {
		return errors.New("pre error")
	})
	var buf bytes.Buffer
	log.SetOutput(&buf)
	// NOTE(review): assumes the default fiber logger writes to os.Stderr.
	t.Cleanup(func() { log.SetOutput(os.Stderr) })

	app.hooks.executeOnPreShutdownHooks()

	require.NotZero(t, buf.Len())
	require.Contains(t, buf.String(), "pre error")
}
// Test_executeOnForkHooks_Error verifies that a failing fork hook is logged
// rather than propagated, and that the hook receives the worker ID.
//
// This test redirects the package-global logger output, so it deliberately
// does NOT call t.Parallel. A parallel sibling that also calls log.SetOutput
// (e.g. the pre-shutdown test) could otherwise swap the writer between our
// SetOutput and the hook's log call, making the buffer assertion flaky. The
// original output is restored when the test finishes.
func Test_executeOnForkHooks_Error(t *testing.T) {
	app := New()
	app.Hooks().OnFork(func(pid int) error {
		require.Equal(t, 1, pid)
		return errors.New("fork error")
	})
	var buf bytes.Buffer
	log.SetOutput(&buf)
	// NOTE(review): assumes the default fiber logger writes to os.Stderr.
	t.Cleanup(func() { log.SetOutput(os.Stderr) })

	app.hooks.executeOnForkHooks(1)

	require.NotZero(t, buf.Len())
	require.Contains(t, buf.String(), "fork error")
}
func Test_executeOnMountHooks_Error(t *testing.T) {
	t.Parallel()
	child := New()
	parent := New()
	child.Hooks().OnMount(func(mountedOn *App) error {
		// The hook receives exactly the app passed to executeOnMountHooks.
		require.Equal(t, parent, mountedOn)
		return errors.New("mount error")
	})

	require.EqualError(t, child.hooks.executeOnMountHooks(parent), "mount error")
}
================================================
FILE: internal/memory/memory.go
================================================
// Package memory provides a high-performance in-memory storage that can store
// any type without encoding overhead. Unlike the standard storage interface,
// this storage works directly with Go types for maximum speed.
//
// # Safety Considerations
//
// This storage automatically performs defensive copying for:
// - String keys: Copied to prevent corruption from pooled buffers
// - []byte values: Copied on both Set and Get to prevent external mutation
//
// For other types (structs, ints, etc.), Go's value semantics provide natural
// protection. However, if storing pointers or slices of non-byte types,
// callers are responsible for not mutating the underlying data.
//
// This storage is primarily used internally by middleware for performance-
// critical operations where the stored data types are known and controlled.
package memory
import (
"sync"
"time"
"github.com/gofiber/utils/v2"
)
// Storage is an in-memory key/value store that keeps values as Go values
// (type any) without any encoding step.
//
// All access to data is guarded by mu. Expired entries are reported as
// missing by Get and are deleted periodically by the gc goroutine that New
// starts.
type Storage struct {
	data map[string]item // data: key -> stored value plus expiry
	mu   sync.RWMutex    // guards data
}
// item is a single stored value together with its expiration.
type item struct {
	v any // val: the stored value
	// e is the expiry in Unix seconds, compared against utils.Timestamp();
	// 0 means the item never expires.
	// max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000
	e uint32 // exp
}
// New returns an empty Storage. It ensures the shared timestamp updater is
// running and starts a garbage-collection goroutine that scans for expired
// entries once per second.
//
// NOTE(review): the gc goroutine has no stop mechanism, so every Storage
// created here keeps one goroutine alive for the life of the process.
func New() *Storage {
	s := &Storage{data: make(map[string]item)}

	utils.StartTimeStampUpdater()
	go s.gc(time.Second)

	return s
}
// Get returns the value stored under key, or nil when the key is unknown or its
// entry has expired.
//
// A stored []byte is returned as a fresh copy so callers cannot mutate the
// stored bytes; every other type is returned exactly as stored.
func (s *Storage) Get(key string) any {
	s.mu.RLock()
	entry, found := s.data[key]
	s.mu.RUnlock()

	if !found {
		return nil
	}
	if entry.e != 0 && entry.e <= utils.Timestamp() {
		return nil
	}

	if raw, isBytes := entry.v.([]byte); isBytes {
		return utils.CopyBytes(raw)
	}
	return entry.v
}
// Set stores val under key. A positive ttl makes the entry expire after that
// duration; a non-positive ttl keeps the item forever.
//
// Expiry is tracked in whole seconds (compared against utils.Timestamp()).
// A positive ttl shorter than one second is therefore rounded up to one
// second. Truncating it to zero would store an expiry equal to the current
// timestamp, and Get would treat the entry as already expired.
//
// String keys are defensively copied to prevent corruption from pooled buffers.
// []byte values are also copied to prevent external mutation of stored data.
// Other types are stored as-is (structs are copied by value automatically).
func (s *Storage) Set(key string, val any, ttl time.Duration) {
	var exp uint32
	if ttl > 0 {
		secs := uint32(ttl.Seconds())
		if secs == 0 {
			// Sub-second ttl: keep the entry for at least one timestamp tick.
			secs = 1
		}
		exp = secs + utils.Timestamp()
	}
	// Defensive copies to prevent unsafe reuse from sync.Pool
	keyCopy := utils.CopyString(key)
	// Copy byte slices to prevent external mutation
	if b, ok := val.([]byte); ok {
		val = utils.CopyBytes(b)
	}
	i := item{e: exp, v: val}
	s.mu.Lock()
	s.data[keyCopy] = i
	s.mu.Unlock()
}
// Delete drops key and its value from the storage; unknown keys are a no-op.
func (s *Storage) Delete(key string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.data, key)
}
// Reset discards every stored entry by swapping in a brand-new map. The new map
// is allocated before the lock is taken to keep the critical section short.
func (s *Storage) Reset() {
	fresh := make(map[string]item)

	s.mu.Lock()
	defer s.mu.Unlock()
	s.data = fresh
}
// gc runs forever, waking every sleep interval to delete expired entries.
//
// Expired keys are first collected under the read lock. Only if there are any
// is the write lock taken. Each candidate is then re-checked, because it may
// have been overwritten with a fresh entry between the two phases.
func (s *Storage) gc(sleep time.Duration) {
	ticker := time.NewTicker(sleep)
	defer ticker.Stop()

	// Reused across ticks to avoid reallocating the candidate list.
	var candidates []string
	for range ticker.C {
		now := utils.Timestamp()

		// Phase 1: collect expired keys while holding only the read lock.
		candidates = candidates[:0]
		s.mu.RLock()
		for k, it := range s.data {
			if it.e != 0 && it.e <= now {
				candidates = append(candidates, k)
			}
		}
		s.mu.RUnlock()

		if len(candidates) == 0 {
			continue // nothing expired; skip the write lock entirely
		}

		// Phase 2: delete only entries that are still expired.
		s.mu.Lock()
		for _, k := range candidates {
			if it := s.data[k]; it.e != 0 && it.e <= now {
				delete(s.data, k)
			}
		}
		s.mu.Unlock()
	}
}
================================================
FILE: internal/memory/memory_test.go
================================================
package memory
import (
"testing"
"time"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/require"
)
// go test -run Test_Memory -v -race
//
// Test_Memory exercises the whole Storage API in sequence: Set/Get without
// expiry, Get of a missing key, expiry after a 1s ttl, Delete, and Reset.
// The steps share one store and one key, so their order matters.
func Test_Memory(t *testing.T) {
	t.Parallel()
	store := New()
	var (
		key     = "john-internal"
		val any = []byte("doe")
		exp     = 1 * time.Second
	)
	// Set key with value
	store.Set(key, val, 0)
	result := store.Get(key)
	require.Equal(t, val, result)
	// Get non-existing key
	result = store.Get("empty")
	require.Nil(t, result)
	// Set key with value and ttl; sleep past the 1s ttl (plus a margin for the
	// one-second timestamp resolution) so Get must report it as expired.
	store.Set(key, val, exp)
	time.Sleep(1100 * time.Millisecond)
	result = store.Get(key)
	require.Nil(t, result)
	// Set key with value and no expiration
	store.Set(key, val, 0)
	result = store.Get(key)
	require.Equal(t, val, result)
	// Delete key
	store.Delete(key)
	result = store.Get(key)
	require.Nil(t, result)
	// Reset all keys
	store.Set("john-reset", val, 0)
	store.Set("doe-reset", val, 0)
	store.Reset()
	// Check if all keys are deleted
	result = store.Get("john-reset")
	require.Nil(t, result)
	result = store.Get("doe-reset")
	require.Nil(t, result)
}
// go test -v -run=^$ -bench=Benchmark_Memory -benchmem -count=4
func Benchmark_Memory(b *testing.B) {
keyLength := 1000
keys := make([]string, keyLength)
for i := range keyLength {
keys[i] = utils.UUIDv4()
}
value := []byte("joe")
ttl := 2 * time.Second
b.Run("fiber_memory", func(b *testing.B) {
d := New()
b.ReportAllocs()
for b.Loop() {
for _, key := range keys {
d.Set(key, value, ttl)
}
for _, key := range keys {
_ = d.Get(key)
}
for _, key := range keys {
d.Delete(key)
}
}
})
}
================================================
FILE: internal/storage/memory/config.go
================================================
package memory
import (
"time"
)
// Config defines the config for storage.
type Config struct {
	// GCInterval is how often the background garbage collector scans for and
	// deletes expired keys. Non-positive values are replaced by the default;
	// sub-second intervals are honored.
	//
	// Default is 10 * time.Second
	GCInterval time.Duration
}

// ConfigDefault is the default config
var ConfigDefault = Config{
	GCInterval: 10 * time.Second,
}

// configDefault returns the first provided Config with defaults filled in, or
// ConfigDefault when none is given.
//
// GCInterval is compared as a Duration rather than as whole seconds, so an
// interval such as 300ms is kept instead of silently becoming the 10s default.
// Any non-positive interval still falls back to the default, which also keeps
// time.NewTicker from panicking.
func configDefault(config ...Config) Config {
	// Return default config if nothing provided
	if len(config) < 1 {
		return ConfigDefault
	}

	// Override default config
	cfg := config[0]

	// Set default values
	if cfg.GCInterval <= 0 {
		cfg.GCInterval = ConfigDefault.GCInterval
	}
	return cfg
}
================================================
FILE: internal/storage/memory/memory.go
================================================
// Package memory is a copy of the in-memory storage from the external storage package.
// It exists so unit tests can exercise code paths that use storages from that package.
package memory
import (
"context"
"fmt"
"sync"
"time"
"github.com/gofiber/utils/v2"
)
// Storage provides an in-memory implementation of the storage interface for
// testing purposes.
//
// db is guarded by mux. A background gc goroutine (started by New) deletes
// expired entries every gcInterval until done is signalled by Close.
type Storage struct {
	db         map[string]Entry // key -> value and expiry
	done       chan struct{}    // stops the gc goroutine
	gcInterval time.Duration    // tick period of the gc goroutine
	mux        sync.RWMutex     // guards db
}
// Entry represents a value stored in memory along with its expiration.
type Entry struct {
	data []byte // stored value; Set stores a copy of the caller's slice
	// expiry is in Unix seconds, compared against utils.Timestamp();
	// 0 means the entry never expires.
	// max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000
	expiry uint32
}
// New creates a memory storage configured by the optional config (defaults are
// applied by configDefault). It makes sure the shared timestamp updater runs and
// starts the garbage collector goroutine, which Close stops.
func New(config ...Config) *Storage {
	cfg := configDefault(config...)

	s := &Storage{
		db:         make(map[string]Entry),
		done:       make(chan struct{}),
		gcInterval: cfg.GCInterval,
	}

	utils.StartTimeStampUpdater()
	go s.gc()

	return s
}
// Get returns a copy of the value stored for key. A missing key, an expired
// entry or an empty key all yield (nil, nil); Get never returns an error.
func (s *Storage) Get(key string) ([]byte, error) {
	if len(key) == 0 {
		return nil, nil
	}

	s.mux.RLock()
	entry, found := s.db[key]
	s.mux.RUnlock()

	if !found {
		return nil, nil
	}
	if entry.expiry != 0 && entry.expiry <= utils.Timestamp() {
		return nil, nil
	}

	// Hand out a copy so callers cannot mutate what is stored.
	return utils.CopyBytes(entry.data), nil
}
// GetWithContext behaves like Get, but first returns the wrapped context error
// (and no value) if ctx is already done.
func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) {
	err := wrapContextError(ctx, "get")
	if err != nil {
		return nil, err
	}
	return s.Get(key)
}
// Set saves a copy of val under a copy of key. A non-zero exp makes the entry
// expire that many whole seconds from now; exp == 0 keeps it indefinitely.
// An empty key or empty value is silently ignored. Set always returns nil.
func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
	// Ain't Nobody Got Time For That
	if key == "" || len(val) == 0 {
		return nil
	}

	var expire uint32
	if exp != 0 {
		expire = uint32(exp.Seconds()) + utils.Timestamp()
	}

	// Copy both key and value so buffers reused via sync.Pool cannot corrupt
	// the stored entry.
	k := utils.CopyString(key)
	entry := Entry{
		data:   utils.CopyBytes(val),
		expiry: expire,
	}

	s.mux.Lock()
	s.db[k] = entry
	s.mux.Unlock()
	return nil
}
// SetWithContext behaves like Set, but first returns the wrapped context error
// without storing anything if ctx is already done.
func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
	err := wrapContextError(ctx, "set")
	if err != nil {
		return err
	}
	return s.Set(key, val, exp)
}
// Delete removes the entry stored for key. An empty key is ignored. Delete
// always returns nil.
func (s *Storage) Delete(key string) error {
	if key != "" {
		s.mux.Lock()
		delete(s.db, key)
		s.mux.Unlock()
	}
	return nil
}
// DeleteWithContext behaves like Delete, but first returns the wrapped context
// error without deleting anything if ctx is already done.
func (s *Storage) DeleteWithContext(ctx context.Context, key string) error {
	err := wrapContextError(ctx, "delete")
	if err != nil {
		return err
	}
	return s.Delete(key)
}
// Reset drops every entry by replacing the map with a new, empty one.
// It always returns nil.
func (s *Storage) Reset() error {
	fresh := make(map[string]Entry)

	s.mux.Lock()
	defer s.mux.Unlock()
	s.db = fresh
	return nil
}
// ResetWithContext behaves like Reset, but first returns the wrapped context
// error without clearing anything if ctx is already done.
func (s *Storage) ResetWithContext(ctx context.Context) error {
	err := wrapContextError(ctx, "reset")
	if err != nil {
		return err
	}
	return s.Reset()
}
// Close stops the background garbage collector and releases resources
// associated with the storage instance. It always returns nil.
//
// The done channel is closed rather than sent on. The gc goroutine exits when
// it observes the close, and a repeated (sequential) Close is a no-op. The
// previous unbuffered send blocked forever on a second Close, because nothing
// was left to receive it. Close is not safe to call concurrently with itself.
func (s *Storage) Close() error {
	select {
	case <-s.done:
		// Already closed; the gc goroutine has been told to stop.
	default:
		close(s.done)
	}
	return nil
}
// gc deletes expired entries every s.gcInterval until s.done is signalled.
//
// Expired keys are collected under the read lock and re-checked under the
// write lock, since an entry may have been replaced in between. Both passes use
// the same "expiry <= now" rule as Get and Keys. The first pass previously used
// "<", so entries expiring exactly at the current timestamp were skipped until
// a later tick.
func (s *Storage) gc() {
	ticker := time.NewTicker(s.gcInterval)
	defer ticker.Stop()
	// Reused across ticks to avoid reallocating the candidate list.
	var expired []string
	for {
		select {
		case <-s.done:
			return
		case <-ticker.C:
			ts := utils.Timestamp()
			expired = expired[:0]
			s.mux.RLock()
			for id, v := range s.db {
				if v.expiry != 0 && v.expiry <= ts {
					expired = append(expired, id)
				}
			}
			s.mux.RUnlock()
			if len(expired) == 0 {
				// avoid locking if nothing to delete
				continue
			}
			s.mux.Lock()
			// Double-checked locking.
			// We might have replaced the item in the meantime.
			for i := range expired {
				v := s.db[expired[i]]
				if v.expiry != 0 && v.expiry <= ts {
					delete(s.db, expired[i])
				}
			}
			s.mux.Unlock()
		}
	}
}
// Conn returns the underlying storage map itself (not a copy). The map must
// not be modified by callers.
//
// NOTE(review): the lock only protects reading the field; the returned map is
// then used without any lock, so reading it while Set/Delete/gc run is a data
// race.
func (s *Storage) Conn() map[string]Entry {
	s.mux.RLock()
	db := s.db
	s.mux.RUnlock()
	return db
}
// Keys returns every key whose entry has not expired. When no such key exists
// (including an empty store) it returns (nil, nil), never an empty slice.
func (s *Storage) Keys() ([][]byte, error) {
	s.mux.RLock()
	defer s.mux.RUnlock()

	if len(s.db) == 0 {
		return nil, nil
	}

	now := utils.Timestamp()
	out := make([][]byte, 0, len(s.db))
	for k, entry := range s.db {
		if entry.expiry != 0 && entry.expiry <= now {
			continue // expired, but not yet collected by gc
		}
		out = append(out, []byte(k))
	}

	if len(out) == 0 {
		return nil, nil
	}
	return out, nil
}
// wrapContextError returns nil while ctx is still active. Once ctx is done it
// returns ctx.Err() wrapped with the storage operation name, so errors.Is
// still matches context.Canceled / context.DeadlineExceeded.
func wrapContextError(ctx context.Context, op string) error {
	err := ctx.Err()
	if err == nil {
		return nil
	}
	return fmt.Errorf("memory storage %s: %w", op, err)
}
================================================
FILE: internal/storage/memory/memory_test.go
================================================
package memory
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// Test_Storage_Memory_Set checks that a single Set makes exactly one key visible.
func Test_Storage_Memory_Set(t *testing.T) {
	t.Parallel()
	store := New()

	require.NoError(t, store.Set("john", []byte("doe"), 0))

	keys, err := store.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}
// Test_Storage_Memory_SetWithContext checks that a cancelled context rejects the
// write and leaves the store empty.
func Test_Storage_Memory_SetWithContext(t *testing.T) {
	t.Parallel()
	store := New()

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancel up front so the write must be refused

	err := store.SetWithContext(ctx, "john", []byte("doe"), 0)
	require.ErrorIs(t, err, context.Canceled)

	keys, err := store.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}
// Test_Storage_Memory_Set_Override checks that writing the same key twice still
// leaves a single key.
func Test_Storage_Memory_Set_Override(t *testing.T) {
	t.Parallel()
	store := New()
	key, val := "john", []byte("doe")

	for range 2 {
		require.NoError(t, store.Set(key, val, 0))
	}

	keys, err := store.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}
// Test_Storage_Memory_Get checks that a stored value is read back unchanged.
func Test_Storage_Memory_Get(t *testing.T) {
	t.Parallel()
	store := New()
	key, val := "john", []byte("doe")

	require.NoError(t, store.Set(key, val, 0))

	got, err := store.Get(key)
	require.NoError(t, err)
	require.Equal(t, val, got)

	keys, err := store.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}
// Test_Storage_Memory_GetWithContext checks that a cancelled context makes the
// read fail without a value, while the stored entry stays in place.
func Test_Storage_Memory_GetWithContext(t *testing.T) {
	t.Parallel()
	store := New()
	key, val := "john", []byte("doe")
	require.NoError(t, store.Set(key, val, 0))

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	got, err := store.GetWithContext(ctx, key)
	require.ErrorIs(t, err, context.Canceled)
	require.Nil(t, got)

	keys, err := store.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}
// Test_Storage_Memory_Set_Expiration checks that an entry with a 1s ttl is gone
// (for both Get and Keys) after 1.5s. It uses a short 300ms GC interval.
func Test_Storage_Memory_Set_Expiration(t *testing.T) {
	t.Parallel()
	var (
		testStore = New(Config{
			GCInterval: 300 * time.Millisecond,
		})
		key = "john"
		val = []byte("doe")
		exp = 1 * time.Second
	)
	err := testStore.Set(key, val, exp)
	require.NoError(t, err)
	// interval + expire + buffer
	time.Sleep(1500 * time.Millisecond)
	result, err := testStore.Get(key)
	require.NoError(t, err)
	require.Empty(t, result)
	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}
// Test_Storage_Memory_Set_Long_Expiration_with_Keys checks that an entry with a
// 3s ttl is still listed by Keys after ~1.1s. After a further 4s it must be gone
// from both Get and Keys. The sleeps make this test take about 5s.
func Test_Storage_Memory_Set_Long_Expiration_with_Keys(t *testing.T) {
	t.Parallel()
	var (
		testStore = New()
		key       = "john"
		val       = []byte("doe")
		exp       = 3 * time.Second
	)
	// A fresh store reports no keys at all (nil, not an empty slice).
	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
	err = testStore.Set(key, val, exp)
	require.NoError(t, err)
	// Well before the 3s ttl: the key must still be visible.
	time.Sleep(1100 * time.Millisecond)
	keys, err = testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
	// Past the ttl (~5.1s total): the key must be gone.
	time.Sleep(4000 * time.Millisecond)
	result, err := testStore.Get(key)
	require.NoError(t, err)
	require.Empty(t, result)
	keys, err = testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}
// Test_Storage_Memory_Get_NotExist checks that reading an unknown key yields no
// value and no error, and that the store stays empty.
func Test_Storage_Memory_Get_NotExist(t *testing.T) {
	t.Parallel()
	store := New()

	got, err := store.Get("notexist")
	require.NoError(t, err)
	require.Empty(t, got)

	keys, err := store.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}
// Test_Storage_Memory_Delete checks that a deleted key can no longer be read and
// is no longer listed.
func Test_Storage_Memory_Delete(t *testing.T) {
	t.Parallel()
	store := New()
	key, val := "john", []byte("doe")

	require.NoError(t, store.Set(key, val, 0))
	keys, err := store.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)

	require.NoError(t, store.Delete(key))

	got, err := store.Get(key)
	require.NoError(t, err)
	require.Empty(t, got)

	keys, err = store.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}
// Test_Storage_Memory_DeleteWithContext checks that a cancelled context refuses
// the delete and leaves the entry intact.
func Test_Storage_Memory_DeleteWithContext(t *testing.T) {
	t.Parallel()
	store := New()
	key, val := "john", []byte("doe")
	require.NoError(t, store.Set(key, val, 0))

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	require.ErrorIs(t, store.DeleteWithContext(ctx, key), context.Canceled)

	got, err := store.Get(key)
	require.NoError(t, err)
	require.Equal(t, val, got)

	keys, err := store.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}
// Test_Storage_Memory_Reset checks that Reset removes every stored key.
func Test_Storage_Memory_Reset(t *testing.T) {
	t.Parallel()
	store := New()
	val := []byte("doe")
	names := []string{"john1", "john2"}

	for _, name := range names {
		require.NoError(t, store.Set(name, val, 0))
	}
	keys, err := store.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 2)

	require.NoError(t, store.Reset())

	for _, name := range names {
		got, getErr := store.Get(name)
		require.NoError(t, getErr)
		require.Empty(t, got)
	}
	keys, err = store.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}
// Test_Storage_Memory_ResetWithContext checks that a cancelled context refuses
// the reset and leaves all entries intact.
func Test_Storage_Memory_ResetWithContext(t *testing.T) {
	t.Parallel()
	store := New()
	val := []byte("doe")
	names := []string{"john1", "john2"}

	for _, name := range names {
		require.NoError(t, store.Set(name, val, 0))
	}
	keys, err := store.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 2)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	require.ErrorIs(t, store.ResetWithContext(ctx), context.Canceled)

	for _, name := range names {
		got, getErr := store.Get(name)
		require.NoError(t, getErr)
		require.Equal(t, val, got)
	}
	keys, err = store.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 2)
}
// Test_Storage_Memory_Close checks that closing a fresh store succeeds.
func Test_Storage_Memory_Close(t *testing.T) {
	t.Parallel()
	err := New().Close()
	require.NoError(t, err)
}
// Test_Storage_Memory_Conn checks that Conn exposes a non-nil map.
func Test_Storage_Memory_Conn(t *testing.T) {
	t.Parallel()
	conn := New().Conn()
	require.NotNil(t, conn)
}
// Benchmarks for Set operation

// Benchmark_Memory_Set measures sequential Set calls, discarding the error.
func Benchmark_Memory_Set(b *testing.B) {
	store := New()
	b.ReportAllocs()
	for b.Loop() {
		_ = store.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
	}
}

// Benchmark_Memory_Set_Parallel measures concurrent Set calls on one key.
func Benchmark_Memory_Set_Parallel(b *testing.B) {
	store := New()
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_ = store.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
		}
	})
}

// Benchmark_Memory_Set_Asserted measures sequential Set calls with an error check.
func Benchmark_Memory_Set_Asserted(b *testing.B) {
	store := New()
	b.ReportAllocs()
	for b.Loop() {
		setErr := store.Set("john", []byte("doe"), 0)
		require.NoError(b, setErr)
	}
}

// Benchmark_Memory_Set_Asserted_Parallel measures concurrent Set calls with an
// error check.
func Benchmark_Memory_Set_Asserted_Parallel(b *testing.B) {
	store := New()
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			setErr := store.Set("john", []byte("doe"), 0)
			require.NoError(b, setErr)
		}
	})
}
// Benchmarks for Get operation

// Benchmark_Memory_Get measures sequential Get calls on a pre-populated key.
func Benchmark_Memory_Get(b *testing.B) {
	store := New()
	require.NoError(b, store.Set("john", []byte("doe"), 0))
	b.ReportAllocs()
	for b.Loop() {
		_, _ = store.Get("john") //nolint:errcheck // error not needed for benchmark
	}
}

// Benchmark_Memory_Get_Parallel measures concurrent Get calls on one key.
func Benchmark_Memory_Get_Parallel(b *testing.B) {
	store := New()
	require.NoError(b, store.Set("john", []byte("doe"), 0))
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_, _ = store.Get("john") //nolint:errcheck // error not needed for benchmark
		}
	})
}

// Benchmark_Memory_Get_Asserted measures sequential Get calls with an error check.
func Benchmark_Memory_Get_Asserted(b *testing.B) {
	store := New()
	require.NoError(b, store.Set("john", []byte("doe"), 0))
	b.ReportAllocs()
	for b.Loop() {
		_, getErr := store.Get("john")
		require.NoError(b, getErr)
	}
}

// Benchmark_Memory_Get_Asserted_Parallel measures concurrent Get calls with an
// error check.
func Benchmark_Memory_Get_Asserted_Parallel(b *testing.B) {
	store := New()
	require.NoError(b, store.Set("john", []byte("doe"), 0))
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_, getErr := store.Get("john")
			require.NoError(b, getErr)
		}
	})
}
// Benchmarks for SetAndDelete operation

// Benchmark_Memory_SetAndDelete measures a Set followed by a Delete of the same
// key, discarding errors.
func Benchmark_Memory_SetAndDelete(b *testing.B) {
	store := New()
	b.ReportAllocs()
	for b.Loop() {
		_ = store.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
		_ = store.Delete("john")                //nolint:errcheck // error not needed for benchmark
	}
}

// Benchmark_Memory_SetAndDelete_Parallel runs Set+Delete pairs concurrently.
func Benchmark_Memory_SetAndDelete_Parallel(b *testing.B) {
	store := New()
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_ = store.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
			_ = store.Delete("john")                //nolint:errcheck // error not needed for benchmark
		}
	})
}

// Benchmark_Memory_SetAndDelete_Asserted measures Set+Delete with error checks.
func Benchmark_Memory_SetAndDelete_Asserted(b *testing.B) {
	store := New()
	b.ReportAllocs()
	for b.Loop() {
		opErr := store.Set("john", []byte("doe"), 0)
		require.NoError(b, opErr)
		opErr = store.Delete("john")
		require.NoError(b, opErr)
	}
}

// Benchmark_Memory_SetAndDelete_Asserted_Parallel runs checked Set+Delete pairs
// concurrently.
func Benchmark_Memory_SetAndDelete_Asserted_Parallel(b *testing.B) {
	store := New()
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			opErr := store.Set("john", []byte("doe"), 0)
			require.NoError(b, opErr)
			opErr = store.Delete("john")
			require.NoError(b, opErr)
		}
	})
}
================================================
FILE: internal/tlstest/tls.go
================================================
package tlstest
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"time"
)
var errAppendCACert = errors.New("failed to append CA certificate to certificate pool")

// testRSAKeyBits is the RSA modulus size for the generated CA and server keys.
// 2048 bits is the common RSA minimum. It keeps key generation fast for test
// suites that call GetTLSConfigs repeatedly; 4096-bit generation is
// considerably slower.
const testRSAKeyBits = 2048

// GetTLSConfigs generates TLS configurations for a test server and client that
// trust each other using an in-memory certificate authority.
//
// It creates:
//   - a self-signed CA (serial 2021);
//   - a server certificate for 127.0.0.1 and ::1 (serial 2022), signed by that CA.
//
// Both are valid for ten years from now. The server config presents the server
// certificate, the client config trusts only the CA, and both require TLS 1.2+.
func GetTLSConfigs() (serverTLSConf, clientTLSConf *tls.Config, err error) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming server and client TLS configurations along with the error
	now := time.Now()
	notAfter := now.AddDate(10, 0, 0)
	subject := pkix.Name{
		Organization:  []string{"Fiber"},
		Country:       []string{"NL"},
		Province:      []string{""},
		Locality:      []string{"Amsterdam"},
		StreetAddress: []string{"Huidenstraat"},
		PostalCode:    []string{"1011 AA"},
	}

	// set up our CA certificate
	ca := &x509.Certificate{
		SerialNumber:          big.NewInt(2021),
		Subject:               subject,
		NotBefore:             now,
		NotAfter:              notAfter,
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}
	caPrivateKey, err := rsa.GenerateKey(rand.Reader, testRSAKeyBits)
	if err != nil {
		return nil, nil, fmt.Errorf("generate CA key: %w", err)
	}
	// self-sign the CA
	caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivateKey.PublicKey, caPrivateKey)
	if err != nil {
		return nil, nil, fmt.Errorf("create CA certificate: %w", err)
	}
	caPEM, err := encodePEM("CERTIFICATE", caBytes)
	if err != nil {
		return nil, nil, fmt.Errorf("encode CA cert: %w", err)
	}

	// set up our server certificate
	cert := &x509.Certificate{
		// RFC 5280 §4.1.2.2: serial numbers must be unique per issuing CA, so the
		// leaf must not reuse the CA's own serial.
		SerialNumber: big.NewInt(2022),
		Subject:      subject,
		IPAddresses:  []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
		NotBefore:    now,
		NotAfter:     notAfter,
		SubjectKeyId: []byte{1, 2, 3, 4, 6},
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:     x509.KeyUsageDigitalSignature,
	}
	certPrivateKey, err := rsa.GenerateKey(rand.Reader, testRSAKeyBits)
	if err != nil {
		return nil, nil, fmt.Errorf("generate server key: %w", err)
	}
	certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivateKey.PublicKey, caPrivateKey)
	if err != nil {
		return nil, nil, fmt.Errorf("create server certificate: %w", err)
	}
	certPEM, err := encodePEM("CERTIFICATE", certBytes)
	if err != nil {
		return nil, nil, fmt.Errorf("encode server cert: %w", err)
	}
	certPrivateKeyPEM, err := encodePEM("RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(certPrivateKey))
	if err != nil {
		return nil, nil, fmt.Errorf("encode server private key: %w", err)
	}

	serverCert, err := tls.X509KeyPair(certPEM, certPrivateKeyPEM)
	if err != nil {
		return nil, nil, fmt.Errorf("load server key pair: %w", err)
	}
	serverTLSConf = &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		MinVersion:   tls.VersionTLS12,
	}

	certPool := x509.NewCertPool()
	if ok := certPool.AppendCertsFromPEM(caPEM); !ok {
		return nil, nil, errAppendCACert
	}
	clientTLSConf = &tls.Config{
		RootCAs:    certPool,
		MinVersion: tls.VersionTLS12,
	}
	return serverTLSConf, clientTLSConf, nil
}

// encodePEM wraps der in a single PEM block of the given type.
func encodePEM(blockType string, der []byte) ([]byte, error) {
	var buf bytes.Buffer
	if err := pem.Encode(&buf, &pem.Block{Type: blockType, Bytes: der}); err != nil {
		return nil, fmt.Errorf("pem encode %s: %w", blockType, err)
	}
	return buf.Bytes(), nil
}
================================================
FILE: listen.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net"
"os"
"path/filepath"
"reflect"
"runtime"
"slices"
"sort"
"strings"
"text/tabwriter"
"time"
"github.com/mattn/go-colorable"
"github.com/mattn/go-isatty"
"golang.org/x/crypto/acme/autocert"
"github.com/gofiber/fiber/v3/log"
)
// Figlet text to show Fiber ASCII art on startup message.
// The trailing %s is a format verb. NOTE(review): presumably filled in by the
// startup-message code (not visible here); confirm before changing it.
var figletFiberText = `
_______ __
/ ____(_) /_ ___ _____
/ /_ / / __ \/ _ \/ ___/
/ __/ / / /_/ / __/ /
/_/ /_/_.___/\___/_/ %s`
// globalIpv4Addr is the IPv4 wildcard address. It is used as the host in
// ListenData when the listen address has no host and the network is not tcp6.
const globalIpv4Addr = "0.0.0.0"
// ListenConfig is a struct to customize startup of Fiber.
//
// It is passed to Listen or Listener. When neither receives a ListenConfig,
// ListenerNetwork, UnixSocketFileMode, TLSMinVersion and ShutdownTimeout all get
// their documented defaults. When a ListenConfig is passed, only
// ListenerNetwork, UnixSocketFileMode and TLSMinVersion are defaulted if zero;
// a zero ShutdownTimeout is kept as-is and disables the timeout.
type ListenConfig struct {
	// GracefulContext is a field to shutdown Fiber by given context gracefully.
	//
	// Default: nil
	GracefulContext context.Context `json:"graceful_context"` //nolint:containedctx // It's needed to set context inside Listen.
	// TLSConfigFunc allows customizing tls.Config as you want.
	// Listen applies it only to a tls.Config built from CertFile/CertKeyFile or
	// AutoCertManager; it is not called when TLSConfig is set.
	//
	// Default: nil
	TLSConfigFunc func(tlsConfig *tls.Config) `json:"tls_config_func"`
	// TLSConfig allows providing a tls.Config used as the base for TLS settings.
	// This enables external certificate providers via GetCertificate.
	// Listen uses a clone of it, and it takes precedence over CertFile,
	// CertKeyFile, CertClientFile and AutoCertManager.
	//
	// Default: nil
	TLSConfig *tls.Config `json:"tls_config"`
	// ListenerAddrFunc allows accessing the address of the net.Listener created
	// by Listen. It is called with the bound address right after the listener is
	// created; Listener (custom listener) does not call it.
	//
	// Default: nil
	ListenerAddrFunc func(addr net.Addr) `json:"listener_addr_func"`
	// BeforeServeFunc allows customizing and accessing fiber app before serving the app.
	// If it returns an error, serving is aborted and that error is returned.
	//
	// Default: nil
	BeforeServeFunc func(app *App) error `json:"before_serve_func"`
	// AutoCertManager manages TLS certificates automatically using the ACME protocol.
	// It enables integration with Let's Encrypt or other ACME-compatible providers.
	// It cannot be combined with CertFile/CertKeyFile (ErrAutoCertWithCertFile).
	//
	// Default: nil
	AutoCertManager *autocert.Manager `json:"auto_cert_manager"`
	// ListenerNetwork is the network passed to net.Listen / tls.Listen.
	// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), "unix" (Unix Domain Sockets)
	// WARNING: When prefork is set to true, only "tcp4" and "tcp6" can be chosen.
	//
	// Default: NetworkTCP4
	ListenerNetwork string `json:"listener_network"`
	// CertFile is a path of certificate file.
	// If you want to use TLS, you have to enter this field (together with CertKeyFile).
	//
	// Default: ""
	CertFile string `json:"cert_file"`
	// CertKeyFile is a path of certificate's private key.
	// If you want to use TLS, you have to enter this field (together with CertFile).
	//
	// Default: ""
	CertKeyFile string `json:"cert_key_file"`
	// CertClientFile is a path of a PEM file with the CA certificate(s) used to
	// verify client certificates. If you want to use mTLS, you have to enter this field.
	// It is only used together with CertFile and CertKeyFile.
	//
	// Default: ""
	CertClientFile string `json:"cert_client_file"`
	// When the graceful shutdown begins, use this field to set the timeout
	// duration. If the timeout is reached, OnPostShutdown will be called with the error.
	// Set to 0 to disable the timeout and wait indefinitely.
	// NOTE: the 10s default only applies when no ListenConfig is passed at all.
	//
	// Default: 10 * time.Second
	ShutdownTimeout time.Duration `json:"shutdown_timeout"`
	// FileMode to set for Unix Domain Socket (ListenerNetwork must be "unix").
	// It is applied with os.Chmod after the socket is bound.
	//
	// Default: 0770
	UnixSocketFileMode os.FileMode `json:"unix_socket_file_mode"`
	// TLSMinVersion allows to set TLS minimum version.
	//
	// Default: tls.VersionTLS12
	// WARNING: TLS1.0 and TLS1.1 versions are not supported; any value other than
	// tls.VersionTLS12 or tls.VersionTLS13 makes Listen/Listener panic.
	TLSMinVersion uint16 `json:"tls_min_version"`
	// When set to true, it will not print out the «Fiber» ASCII art and listening address.
	//
	// Default: false
	DisableStartupMessage bool `json:"disable_startup_message"`
	// When set to true, this will spawn multiple Go processes listening on the same port.
	// Not supported by Listener (custom listeners); there it only logs a warning.
	//
	// Default: false
	EnablePrefork bool `json:"enable_prefork"`
	// If set to true, will print all routes with their method, path and handler.
	//
	// Default: false
	EnablePrintRoutes bool `json:"enable_print_routes"`
}
// listenConfigDefault returns the effective ListenConfig.
//
// With no argument it returns the built-in defaults: TCP4, socket mode 0770,
// TLS 1.2 minimum and a 10s shutdown timeout. With an argument it fills in
// ListenerNetwork, UnixSocketFileMode and TLSMinVersion only where they are
// zero. ShutdownTimeout is kept as given, so 0 disables the timeout.
// It panics when TLSMinVersion is neither TLS 1.2 nor TLS 1.3.
func listenConfigDefault(config ...ListenConfig) ListenConfig {
	if len(config) == 0 {
		return ListenConfig{
			ListenerNetwork:    NetworkTCP4,
			ShutdownTimeout:    10 * time.Second,
			UnixSocketFileMode: 0o770,
			TLSMinVersion:      tls.VersionTLS12,
		}
	}

	cfg := config[0]
	if cfg.ListenerNetwork == "" {
		cfg.ListenerNetwork = NetworkTCP4
	}
	if cfg.UnixSocketFileMode == 0 {
		cfg.UnixSocketFileMode = 0o770
	}

	switch cfg.TLSMinVersion {
	case 0:
		cfg.TLSMinVersion = tls.VersionTLS12
	case tls.VersionTLS12, tls.VersionTLS13:
		// supported as given
	default:
		panic("unsupported TLS version, please use tls.VersionTLS12 or tls.VersionTLS13")
	}
	return cfg
}
// Listen serves HTTP requests from the given addr.
// You should enter custom ListenConfig to customize startup. (TLS, mTLS, prefork...)
//
//	app.Listen(":8080")
//	app.Listen("127.0.0.1:8080")
//	app.Listen(":8080", ListenConfig{EnablePrefork: true})
//
// TLS comes from, in order of precedence:
//   - cfg.TLSConfig (cloned);
//   - CertFile + CertKeyFile, optionally with CertClientFile for mTLS;
//   - AutoCertManager.
//
// TLSConfigFunc is applied only to configs built from the latter two.
// Combining AutoCertManager with CertFile or CertKeyFile returns
// ErrAutoCertWithCertFile. If CertClientFile contains no usable PEM
// certificate, an error wrapping os.ErrInvalid is returned. Otherwise the
// server would start with an empty client CA pool and reject every client.
//
// If BeforeServeFunc fails, the listener created here is closed before its
// error is returned, so the address is not left bound.
func (app *App) Listen(addr string, config ...ListenConfig) error {
	cfg := listenConfigDefault(config...)

	// Configure TLS
	var tlsConfig *tls.Config
	if cfg.TLSConfig != nil {
		tlsConfig = cfg.TLSConfig.Clone()
	} else {
		switch {
		case cfg.AutoCertManager != nil && (cfg.CertFile != "" || cfg.CertKeyFile != ""):
			return ErrAutoCertWithCertFile
		case cfg.CertFile != "" && cfg.CertKeyFile != "":
			cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.CertKeyFile)
			if err != nil {
				return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", cfg.CertFile, cfg.CertKeyFile, err)
			}
			tlsHandler := &TLSHandler{}
			tlsConfig = &tls.Config{ //nolint:gosec // This is a user input
				MinVersion: cfg.TLSMinVersion,
				Certificates: []tls.Certificate{
					cert,
				},
				GetCertificate: tlsHandler.GetClientInfo,
			}
			if cfg.CertClientFile != "" {
				clientCACert, err := os.ReadFile(filepath.Clean(cfg.CertClientFile))
				if err != nil {
					return fmt.Errorf("failed to read file: %w", err)
				}
				clientCertPool := x509.NewCertPool()
				// AppendCertsFromPEM reports false when no certificate could be
				// parsed; an empty pool would make every client handshake fail.
				if !clientCertPool.AppendCertsFromPEM(clientCACert) {
					return fmt.Errorf("tls: no valid PEM certificates found in certClientFile=%q: %w", cfg.CertClientFile, os.ErrInvalid)
				}
				tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
				tlsConfig.ClientCAs = clientCertPool
			}
			// Attach the tlsHandler to the config
			app.SetTLSHandler(tlsHandler)
		case cfg.AutoCertManager != nil:
			tlsConfig = &tls.Config{ //nolint:gosec // This is a user input
				MinVersion:     cfg.TLSMinVersion,
				GetCertificate: cfg.AutoCertManager.GetCertificate,
				NextProtos:     []string{"http/1.1", "acme-tls/1"},
			}
		default:
		}
		if tlsConfig != nil && cfg.TLSConfigFunc != nil {
			cfg.TLSConfigFunc(tlsConfig)
		}
	}

	// Graceful shutdown
	if cfg.GracefulContext != nil {
		ctx, cancel := context.WithCancel(cfg.GracefulContext)
		defer cancel()
		go app.gracefulShutdown(ctx, &cfg)
	}

	// Start prefork
	if cfg.EnablePrefork {
		return app.prefork(addr, tlsConfig, &cfg)
	}

	// Configure Listener
	ln, err := app.createListener(addr, tlsConfig, &cfg)
	if err != nil {
		return fmt.Errorf("failed to listen: %w", err)
	}

	// prepare the server for the start
	app.startupProcess()
	listenData := app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil, &cfg, nil)

	// run hooks
	app.runOnListenHooks(listenData)

	// Print startup message & routes
	app.printMessages(&cfg, listenData)

	// Serve
	if cfg.BeforeServeFunc != nil {
		if err := cfg.BeforeServeFunc(app); err != nil {
			// Listen created ln, so it must release it; otherwise the address
			// stays bound for the rest of the process.
			_ = ln.Close() //nolint:errcheck // best-effort cleanup; the BeforeServeFunc error is the one reported
			return err
		}
	}

	return app.server.Serve(ln)
}
// Listener serves HTTP requests from the given listener.
// You should enter custom ListenConfig to customize startup. (prefork, startup message, graceful shutdown...)
//
// Unlike Listen, no TLS or listener setup happens here. ln is used as given,
// and whether it is a TLS listener is detected via getTLSConfig(ln). If
// BeforeServeFunc fails, its error is returned and ln is left open; the caller
// owns it. EnablePrefork is not supported for custom listeners. It only logs a
// warning, and the app is still served on ln in this process.
func (app *App) Listener(ln net.Listener, config ...ListenConfig) error {
	cfg := listenConfigDefault(config...)

	// Graceful shutdown: watch GracefulContext in the background; the derived
	// context is cancelled when Listener returns.
	if cfg.GracefulContext != nil {
		ctx, cancel := context.WithCancel(cfg.GracefulContext)
		defer cancel()
		go app.gracefulShutdown(ctx, &cfg)
	}

	// prepare the server for the start
	app.startupProcess()
	listenData := app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil, &cfg, nil)

	// run hooks
	app.runOnListenHooks(listenData)

	// Print startup message & routes
	app.printMessages(&cfg, listenData)

	// Serve
	if cfg.BeforeServeFunc != nil {
		if err := cfg.BeforeServeFunc(app); err != nil {
			return err
		}
	}

	// Prefork is not supported for custom listeners
	if cfg.EnablePrefork {
		log.Warn("Prefork isn't supported for custom listeners.")
	}

	return app.server.Serve(ln)
}
// createListener binds a net.Listener for addr according to cfg; a nil cfg is
// treated as a zero ListenConfig.
//
// For NetworkUnix it removes any stale socket file at addr before binding, and
// afterwards chmods the socket to cfg.UnixSocketFileMode. When tlsConfig is
// non-nil the listener is created with tls.Listen. If set, cfg.ListenerAddrFunc
// is called with the bound address.
//
// If the post-bind chmod fails, the already-bound listener is closed before the
// error is returned; previously the socket was leaked.
func (*App) createListener(addr string, tlsConfig *tls.Config, cfg *ListenConfig) (net.Listener, error) {
	if cfg == nil {
		cfg = &ListenConfig{}
	}
	var listener net.Listener
	var err error

	// Remove previously created socket, to make sure it's possible to listen
	if cfg.ListenerNetwork == NetworkUnix {
		if err = os.Remove(addr); err != nil && !os.IsNotExist(err) {
			return nil, fmt.Errorf("unexpected error when trying to remove unix socket file %q: %w", addr, err)
		}
	}

	if tlsConfig != nil {
		listener, err = tls.Listen(cfg.ListenerNetwork, addr, tlsConfig)
	} else {
		listener, err = net.Listen(cfg.ListenerNetwork, addr)
	}

	// Check for error before using the listener
	if err != nil {
		// Wrap the error from tls.Listen/net.Listen
		return nil, fmt.Errorf("failed to listen: %w", err)
	}

	if cfg.ListenerNetwork == NetworkUnix {
		if err = os.Chmod(addr, cfg.UnixSocketFileMode); err != nil {
			// The socket is already bound; release it so the caller does not
			// leak a listener it never receives.
			_ = listener.Close() //nolint:errcheck // best-effort cleanup; the chmod error is the one reported
			return nil, fmt.Errorf("cannot chmod %#o for %q: %w", cfg.UnixSocketFileMode, addr, err)
		}
	}

	if cfg.ListenerAddrFunc != nil {
		cfg.ListenerAddrFunc(listener.Addr())
	}

	return listener, nil
}
// printMessages prints the startup banner for listenData. When
// cfg.EnablePrintRoutes is set, it also prints the registered-routes table
// afterwards.
func (app *App) printMessages(cfg *ListenConfig, listenData *ListenData) {
	app.startupMessage(listenData, cfg)
	if !cfg.EnablePrintRoutes {
		return
	}
	app.printRoutesMessage()
}
// prepareListenData creates a ListenData instance populated with the application metadata.
//
// Host and port are split out of addr by parseAddr. An empty host is replaced
// by "[::1]" for NetworkTCP6 and by globalIpv4Addr otherwise. ProcessCount is
// runtime.GOMAXPROCS(0) when prefork is enabled and 1 otherwise. childPIDs is
// copied, so the caller may reuse its slice; an empty input yields a nil
// ChildPIDs.
func (app *App) prepareListenData(addr string, isTLS bool, cfg *ListenConfig, childPIDs []int) *ListenData { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS is fine here
	host, port := parseAddr(addr)
	if host == "" {
		host = globalIpv4Addr
		if cfg.ListenerNetwork == NetworkTCP6 {
			host = "[::1]"
		}
	}

	workers := 1
	if cfg.EnablePrefork {
		workers = runtime.GOMAXPROCS(0)
	}

	data := &ListenData{
		Host:         host,
		Port:         port,
		Version:      Version,
		AppName:      app.config.AppName,
		ColorScheme:  app.config.ColorScheme,
		HandlerCount: int(app.handlersCount),
		ProcessCount: workers,
		PID:          os.Getpid(),
		TLS:          isTLS,
		Prefork:      cfg.EnablePrefork,
	}
	if len(childPIDs) > 0 {
		data.ChildPIDs = slices.Clone(childPIDs)
	}
	return data
}
// startupMessage renders the startup banner using the provided listener metadata and configuration.
//
// It proceeds in this order:
//   - Adds the default info entries: address, app name, handler count,
//     prefork, PID and process count.
//   - Runs the OnPreStartupMessage hooks, which may customize the entries or
//     banner.
//   - Prints the banner, unless the message is disabled, this is a prefork
//     child, or a hook set PreventDefault.
//
// The OnPostStartupMessage hooks always run (deferred), even when nothing is
// printed. Color escape codes are stripped when TERM=dumb, when NO_COLOR is
// set to any non-empty value, or when stdout is not a terminal.
func (app *App) startupMessage(listenData *ListenData, cfg *ListenConfig) {
	preData := newPreStartupMessageData(listenData)
	colors := listenData.ColorScheme
	out := colorable.NewColorableStdout()
	// Per the NO_COLOR convention (no-color.org), any non-empty value disables color.
	if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") != "" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
		out = colorable.NewNonColorable(os.Stdout)
	}
	// Add default entries
	scheme := schemeHTTP
	if listenData.TLS {
		scheme = schemeHTTPS
	}
	if listenData.Host == globalIpv4Addr {
		preData.AddInfo("server_address", "Server started on", fmt.Sprintf("%s%s://127.0.0.1:%s%s (bound on host 0.0.0.0 and port %s)",
			colors.Blue, scheme, listenData.Port, colors.Reset, listenData.Port), 10)
	} else {
		preData.AddInfo("server_address", "Server started on", fmt.Sprintf("%s%s://%s:%s%s",
			colors.Blue, scheme, listenData.Host, listenData.Port, colors.Reset), 10)
	}
	if listenData.AppName != "" {
		preData.AddInfo("app_name", "Application name", fmt.Sprintf("\t%s%s%s", colors.Blue, listenData.AppName, colors.Reset), 9)
	}
	preData.AddInfo("total_handlers", "Total handlers", fmt.Sprintf("\t%s%d%s", colors.Blue, listenData.HandlerCount, colors.Reset), 8)
	if listenData.Prefork {
		preData.AddInfo("prefork", "Prefork", fmt.Sprintf("\t\t%sEnabled%s", colors.Blue, colors.Reset), 7)
	} else {
		preData.AddInfo("prefork", "Prefork", fmt.Sprintf("\t\t%sDisabled%s", colors.Red, colors.Reset), 6)
	}
	preData.AddInfo("pid", "PID", fmt.Sprintf("\t\t%s%d%s", colors.Blue, listenData.PID, colors.Reset), 5)
	preData.AddInfo("process_count", "Total process count", fmt.Sprintf("%s%d%s", colors.Blue, listenData.ProcessCount, colors.Reset), 4)
	if err := app.hooks.executeOnPreStartupMessageHooks(preData); err != nil {
		log.Errorf("failed to call pre startup message hook: %v", err)
	}
	disabled := cfg.DisableStartupMessage
	isChild := IsChild()
	prevented := preData != nil && preData.PreventDefault
	// Post hooks run on every exit path, including the early return below.
	defer func() {
		postData := newPostStartupMessageData(listenData, disabled, isChild, prevented)
		if err := app.hooks.executeOnPostStartupMessageHooks(postData); err != nil {
			log.Errorf("failed to call post startup message hook: %v", err)
		}
	}()
	if preData == nil || disabled || isChild || prevented {
		return
	}
	if preData.BannerHeader != "" {
		header := preData.BannerHeader
		fmt.Fprint(out, header)
		if !strings.HasSuffix(header, "\n") {
			fmt.Fprintln(out)
		}
	} else {
		fmt.Fprintf(out, "%s\n", fmt.Sprintf(figletFiberText, colors.Red+"v"+listenData.Version+colors.Reset))
		fmt.Fprintln(out, strings.Repeat("-", 50))
	}
	printStartupEntries(out, &colors, preData.entries)
	if err := app.logServices(app.servicesStartupCtx(), out, &colors); err != nil {
		log.Errorf("failed to log services: %v", err)
	}
	if listenData.Prefork && len(listenData.ChildPIDs) > 0 {
		fmt.Fprintf(out, "%sINFO%s Child PIDs: \t\t%s", colors.Green, colors.Reset, colors.Blue)
		totalPIDs := len(listenData.ChildPIDs)
		rowTotalPidCount := 10
		for start := 0; start < totalPIDs; start += rowTotalPidCount {
			end := min(start+rowTotalPidCount, totalPIDs)
			if start > 0 {
				// The previous row ended with a color reset; re-apply the PID color.
				fmt.Fprint(out, colors.Blue)
			}
			row := listenData.ChildPIDs[start:end]
			for idx, pid := range row {
				fmt.Fprintf(out, "%d", pid)
				if idx+1 != len(row) {
					fmt.Fprint(out, ", ")
				}
			}
			fmt.Fprintf(out, "\n%s", colors.Reset)
		}
	}
	fmt.Fprintf(out, "\n%s", colors.Reset)
}
// printStartupEntries writes one line per startup entry to out, ordered by
// descending priority.
//
// Entries with equal priority keep the order in which they were added
// (stable sort), so the banner is deterministic. Each line starts with a
// colored level label:
//   - WARN (yellow) for StartupMessageLevelWarning
//   - errString (red) for StartupMessageLevelError
//   - INFO (green) otherwise
//
// Note: entries is sorted in place.
func printStartupEntries(out io.Writer, colors *Colors, entries []startupMessageEntry) {
	// Sort entries by priority (higher priority first). SliceStable keeps
	// insertion order for ties; sort.Slice would leave it unspecified.
	sort.SliceStable(entries, func(i, j int) bool {
		return entries[i].priority > entries[j].priority
	})
	for _, entry := range entries {
		var label string
		var color string
		switch entry.level {
		case StartupMessageLevelWarning:
			label, color = "WARN", colors.Yellow
		case StartupMessageLevelError:
			label, color = errString, colors.Red
		default:
			label, color = "INFO", colors.Green
		}
		fmt.Fprintf(out, "%s%s%s %s: \t%s%s%s\n", color, label, colors.Reset, entry.title, colors.Blue, entry.value, colors.Reset)
	}
}
// printRoutesMessage prints all routes with method, path, name and handlers
// in a format of table, like this:
//
//	method | path | name      | handlers
//	GET    | /    | routeName | github.com/gofiber/fiber/v3.emptyHandler
//	HEAD   | /    |           | github.com/gofiber/fiber/v3.emptyHandler
//
// Nothing is printed in prefork child processes. Rows are ordered by path.
// Routes that share a path keep the order in which they appear in app.stack
// (stable sort), so the output is deterministic. Color escape codes are
// stripped when TERM=dumb, when NO_COLOR is set to any non-empty value, or
// when stdout is not a terminal.
func (app *App) printRoutesMessage() {
	// ignore child processes
	if IsChild() {
		return
	}
	// Alias colors
	colors := app.config.ColorScheme
	var routes []RouteMessage
	for _, routeStack := range app.stack {
		for _, route := range routeStack {
			var newRoute RouteMessage
			newRoute.name = route.Name
			newRoute.method = route.Method
			newRoute.path = route.Path
			for _, handler := range route.Handlers {
				newRoute.handlers += runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name() + " "
			}
			routes = append(routes, newRoute)
		}
	}
	out := colorable.NewColorableStdout()
	// Per the NO_COLOR convention (no-color.org), any non-empty value disables color.
	if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") != "" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) {
		out = colorable.NewNonColorable(os.Stdout)
	}
	w := tabwriter.NewWriter(out, 1, 1, 1, ' ', 0)
	// Sort routes by path. SliceStable keeps routes on the same path
	// (e.g. GET and HEAD for "/") in app.stack order instead of a random order.
	sort.SliceStable(routes, func(i, j int) bool {
		return routes[i].path < routes[j].path
	})
	fmt.Fprintf(w, "%smethod\t%s| %spath\t%s| %sname\t%s| %shandlers\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset)
	fmt.Fprintf(w, "%s------\t%s| %s----\t%s| %s----\t%s| %s--------\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset)
	for _, route := range routes {
		fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers, colors.Reset)
	}
	_ = w.Flush() //nolint:errcheck // It is fine to ignore the error here
}
// gracefulShutdown blocks until ctx is done and then shuts the app down.
//
// It uses ShutdownWithTimeout when cfg carries a non-zero ShutdownTimeout, and
// plain Shutdown otherwise (including when cfg is nil). The resulting error,
// nil on success, is handed to the post-shutdown hooks. It is intended to run
// in its own goroutine.
func (app *App) gracefulShutdown(ctx context.Context, cfg *ListenConfig) {
	<-ctx.Done()

	var shutdownErr error
	if cfg == nil || cfg.ShutdownTimeout == 0 {
		shutdownErr = app.Shutdown() //nolint:contextcheck // TODO: Implement it
	} else {
		shutdownErr = app.ShutdownWithTimeout(cfg.ShutdownTimeout) //nolint:contextcheck // TODO: Implement it
	}

	app.hooks.executeOnPostShutdownHooks(shutdownErr)
}
================================================
FILE: listen_test.go
================================================
package fiber
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log" //nolint:depguard // TODO: Required to capture output, use internal log package instead
"net"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
"golang.org/x/crypto/acme/autocert"
)
// go test -run Test_Listen
// Test_Listen checks that an out-of-range port (":99999") makes Listen fail.
// It also checks that Listen(":0") returns nil once the app is shut down from
// another goroutine.
func Test_Listen(t *testing.T) {
	app := New()
	require.Error(t, app.Listen(":99999"))
	// Shut the app down after one second so the blocking Listen below returns.
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true}))
}
// go test -run Test_Listen_Graceful_Shutdown
// Test_Listen_Graceful_Shutdown runs testGracefulShutdown three times: with no
// shutdown timeout, with a generous timeout, and with a 1ns timeout that is
// expected to expire.
func Test_Listen_Graceful_Shutdown(t *testing.T) {
	scenarios := []struct {
		name    string
		timeout time.Duration
	}{
		{name: "Basic Graceful Shutdown", timeout: 0},
		{name: "Shutdown With Timeout", timeout: 500 * time.Millisecond},
		{name: "Shutdown With Timeout Error", timeout: 1 * time.Nanosecond},
	}
	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			testGracefulShutdown(t, sc.timeout)
		})
	}
}
// testGracefulShutdown serves app.Listener on an in-memory listener with a
// GracefulContext that expires after 2 seconds and the given ShutdownTimeout.
// It then checks that:
//   - a request succeeds before the deadline;
//   - a request fails with ErrInmemoryListenerClosed after it;
//   - the OnPostShutdown hook ran;
//   - for the 1ns timeout, the hook received context.DeadlineExceeded.
func testGracefulShutdown(t *testing.T, shutdownTimeout time.Duration) {
	t.Helper()
	var mu sync.Mutex
	var shutdown bool
	var receivedErr error
	app := New()
	app.Get("/", func(c Ctx) error {
		time.Sleep(10 * time.Millisecond)
		return c.SendString(c.Hostname())
	})
	ln := fasthttputil.NewInmemoryListener()
	errs := make(chan error, 1)
	app.hooks.OnPostShutdown(func(err error) error {
		mu.Lock()
		defer mu.Unlock()
		shutdown = true
		receivedErr = err
		return nil
	})
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		errs <- app.Listener(ln, ListenConfig{
			DisableStartupMessage: true,
			GracefulContext:       ctx,
			ShutdownTimeout:       shutdownTimeout,
		})
	}()
	require.Eventually(t, func() bool {
		conn, err := ln.Dial()
		if err == nil {
			if err := conn.Close(); err != nil {
				t.Logf("error closing connection: %v", err)
			}
			return true
		}
		return false
	}, time.Second, 100*time.Millisecond, "Server failed to become ready")
	client := fasthttp.HostClient{
		Dial: func(_ string) (net.Conn, error) { return ln.Dial() },
	}
	type testCase struct {
		expectedErr        error
		expectedBody       string
		name               string
		waitTime           time.Duration
		expectedStatusCode int
		closeConnection    bool
	}
	testCases := []testCase{
		{
			name:               "Server running normally",
			waitTime:           500 * time.Millisecond,
			expectedBody:       "example.com",
			expectedStatusCode: StatusOK,
			expectedErr:        nil,
			closeConnection:    true,
		},
		{
			name:               "Server shutdown complete",
			waitTime:           3 * time.Second,
			expectedBody:       "",
			expectedStatusCode: StatusOK,
			expectedErr:        fasthttputil.ErrInmemoryListenerClosed,
			closeConnection:    true,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			time.Sleep(tc.waitTime)
			req := fasthttp.AcquireRequest()
			defer fasthttp.ReleaseRequest(req)
			req.SetRequestURI("http://example.com")
			resp := fasthttp.AcquireResponse()
			defer fasthttp.ReleaseResponse(resp)
			err := client.Do(req, resp)
			if tc.expectedErr == nil {
				require.NoError(t, err)
				require.Equal(t, tc.expectedStatusCode, resp.StatusCode())
				require.Equal(t, tc.expectedBody, utils.UnsafeString(resp.Body()))
			} else {
				require.ErrorIs(t, err, tc.expectedErr)
			}
		})
	}

	// Snapshot the hook state under the lock and assert afterwards. A failing
	// require calls runtime.Goexit, which would otherwise leave mu locked.
	mu.Lock()
	gotShutdown, gotErr := shutdown, receivedErr
	mu.Unlock()

	require.True(t, gotShutdown)
	if shutdownTimeout == 1*time.Nanosecond {
		require.Error(t, gotErr)
		require.ErrorIs(t, gotErr, context.DeadlineExceeded)
	}
	require.NoError(t, <-errs)
}
// go test -run Test_Listen_Prefork
// Test_Listen_Prefork checks that Listen with EnablePrefork returns nil.
// NOTE(review): testPreforkMaster is a package-level switch that is set here
// and never reset. Presumably it puts prefork into a test mode where Listen
// returns without forking; confirm in the prefork implementation.
func Test_Listen_Prefork(t *testing.T) {
	testPreforkMaster = true
	app := New()
	require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true}))
}
// go test -run Test_Listen_TLSMinVersion
// Test_Listen_TLSMinVersion checks that Listen panics for TLSMinVersion values
// TLS 1.0 and 1.1, with and without prefork. It also checks that TLS 1.3 is
// accepted in both modes.
func Test_Listen_TLSMinVersion(t *testing.T) {
	testPreforkMaster = true
	app := New()
	// Invalid TLSMinVersion
	require.Panics(t, func() {
		_ = app.Listen(":0", ListenConfig{TLSMinVersion: tls.VersionTLS10}) //nolint:errcheck // ignore error
	})
	require.Panics(t, func() {
		_ = app.Listen(":0", ListenConfig{TLSMinVersion: tls.VersionTLS11}) //nolint:errcheck // ignore error
	})
	// Prefork
	require.Panics(t, func() {
		_ = app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS10}) //nolint:errcheck // ignore error
	})
	require.Panics(t, func() {
		_ = app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS11}) //nolint:errcheck // ignore error
	})
	// Valid TLSMinVersion
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{TLSMinVersion: tls.VersionTLS13}))
	// Valid TLSMinVersion with Prefork
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS13}))
}
// go test -run Test_Listen_TLS
// Test_Listen_TLS checks that a TLS Listen on an out-of-range port fails. It
// also checks that a TLS Listen on ":0" with the test certificate/key returns
// nil after shutdown.
func Test_Listen_TLS(t *testing.T) {
	app := New()
	// invalid port
	require.Error(t, app.Listen(":99999", ListenConfig{
		CertFile:    "./.github/testdata/ssl.pem",
		CertKeyFile: "./.github/testdata/ssl.key",
	}))
	// Shut down after one second so the blocking Listen below returns.
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		CertFile:    "./.github/testdata/ssl.pem",
		CertKeyFile: "./.github/testdata/ssl.key",
	}))
}
// go test -run Test_Listen_TLS_Prefork
// Test_Listen_TLS_Prefork checks that a prefork TLS Listen fails when the key
// file does not contain a key (a template file is used). It also checks that
// the same setup with a valid key returns nil after shutdown.
func Test_Listen_TLS_Prefork(t *testing.T) {
	testPreforkMaster = true
	app := New()
	// invalid key file content
	require.Error(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		EnablePrefork:         true,
		CertFile:              "./.github/testdata/ssl.pem",
		CertKeyFile:           "./.github/testdata/template.tmpl",
	}))
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		EnablePrefork:         true,
		CertFile:              "./.github/testdata/ssl.pem",
		CertKeyFile:           "./.github/testdata/ssl.key",
	}))
}
// go test -run Test_Listen_MutualTLS
// Test_Listen_MutualTLS is the mutual-TLS variant of Test_Listen_TLS: it adds a
// client CA chain (CertClientFile). It checks that an out-of-range port fails
// and that ":0" serves until shutdown.
func Test_Listen_MutualTLS(t *testing.T) {
	app := New()
	// invalid port
	require.Error(t, app.Listen(":99999", ListenConfig{
		CertFile:       "./.github/testdata/ssl.pem",
		CertKeyFile:    "./.github/testdata/ssl.key",
		CertClientFile: "./.github/testdata/ca-chain.cert.pem",
	}))
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		CertFile:       "./.github/testdata/ssl.pem",
		CertKeyFile:    "./.github/testdata/ssl.key",
		CertClientFile: "./.github/testdata/ca-chain.cert.pem",
	}))
}
// go test -run Test_Listen_MutualTLS_Prefork
// Test_Listen_MutualTLS_Prefork is the prefork variant of Test_Listen_MutualTLS.
// A key file that exists but holds no key must make Listen fail, and a valid
// certificate/key/client-CA set must serve until shutdown.
func Test_Listen_MutualTLS_Prefork(t *testing.T) {
	testPreforkMaster = true
	app := New()
	// invalid key file content: template.tmpl exists but contains no PEM key.
	// The previous path (template.html) does not exist in .github/testdata, so
	// the case only exercised the missing-file error.
	require.Error(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		EnablePrefork:         true,
		CertFile:              "./.github/testdata/ssl.pem",
		CertKeyFile:           "./.github/testdata/template.tmpl",
		CertClientFile:        "./.github/testdata/ca-chain.cert.pem",
	}))
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		EnablePrefork:         true,
		CertFile:              "./.github/testdata/ssl.pem",
		CertKeyFile:           "./.github/testdata/ssl.key",
		CertClientFile:        "./.github/testdata/ca-chain.cert.pem",
	}))
}
// go test -run Test_Listener
// Test_Listener checks that App.Listener serves on a caller-supplied in-memory
// listener and returns nil once the app is shut down.
func Test_Listener(t *testing.T) {
	app := New()
	go func() {
		time.Sleep(500 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	ln := fasthttputil.NewInmemoryListener()
	require.NoError(t, app.Listener(ln))
}
// Test_App_Listener_TLS_Listener checks that App.Listener can serve on a TLS
// listener created by the caller, and returns nil once the app is shut down.
func Test_App_Listener_TLS_Listener(t *testing.T) {
	// Create tls certificate
	cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
	require.NoError(t, err)
	//nolint:gosec // We're in a test so using old ciphers is fine
	config := &tls.Config{Certificates: []tls.Certificate{cer}}
	//nolint:gosec // We're in a test so listening on all interfaces is fine
	ln, err := tls.Listen(NetworkTCP4, ":0", config)
	require.NoError(t, err)
	app := New()
	go func() {
		time.Sleep(time.Millisecond * 500)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listener(ln))
}
// go test -run Test_Listen_TLSConfigFunc
// Test_Listen_TLSConfigFunc checks that TLSConfigFunc is invoked when TLS is
// configured through CertFile/CertKeyFile.
func Test_Listen_TLSConfigFunc(t *testing.T) {
	var callTLSConfig bool
	app := New()
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		TLSConfigFunc: func(_ *tls.Config) {
			callTLSConfig = true
		},
		CertFile:    "./.github/testdata/ssl.pem",
		CertKeyFile: "./.github/testdata/ssl.key",
	}))
	// Read only after Listen has returned.
	require.True(t, callTLSConfig)
}
// go test -run Test_Listen_TLSConfig
// Test_Listen_TLSConfig checks that Listen serves with a caller-provided
// tls.Config, whether it supplies Certificates or GetCertificate. The last
// case sets CertFile/CertKeyFile/CertClientFile to non-existent paths plus an
// AutoCertManager. Listen still succeeds, which shows those fields are not
// used when TLSConfig is set.
func Test_Listen_TLSConfig(t *testing.T) {
	t.Parallel()
	cert, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
	require.NoError(t, err)
	// run starts a fresh app with cfg in a parallel subtest and shuts it down after 1s.
	run := func(name string, cfg ListenConfig) {
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			app := New()
			go func() {
				time.Sleep(1000 * time.Millisecond)
				assert.NoError(t, app.Shutdown())
			}()
			require.NoError(t, app.Listen(":0", cfg))
		})
	}
	run("TLSConfig with certificates", ListenConfig{
		DisableStartupMessage: true,
		TLSConfig: &tls.Config{
			MinVersion:   tls.VersionTLS12,
			Certificates: []tls.Certificate{cert},
		},
	})
	run("TLSConfig with GetCertificate", ListenConfig{
		DisableStartupMessage: true,
		TLSConfig: &tls.Config{
			MinVersion: tls.VersionTLS12,
			GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
				return &cert, nil
			},
		},
	})
	run("TLSConfig ignores other TLS fields", ListenConfig{
		DisableStartupMessage: true,
		TLSConfig: &tls.Config{
			MinVersion:   tls.VersionTLS12,
			Certificates: []tls.Certificate{cert},
		},
		CertFile:       "./.github/testdata/does-not-exist.pem",
		CertKeyFile:    "./.github/testdata/does-not-exist.key",
		CertClientFile: "./.github/testdata/does-not-exist-ca.pem",
		AutoCertManager: &autocert.Manager{
			Prompt: autocert.AcceptTOS,
		},
	})
}
// go test -run Test_Listen_TLSCertFiles
// Test_Listen_TLSCertFiles checks that Listen serves when the certificate, key
// and client CA are all given as file paths. The server certificate doubles as
// the client CA file here.
func Test_Listen_TLSCertFiles(t *testing.T) {
	t.Parallel()
	app := New()
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		CertFile:              "./.github/testdata/ssl.pem",
		CertKeyFile:           "./.github/testdata/ssl.key",
		CertClientFile:        "./.github/testdata/ssl.pem",
	}))
}
// go test -run Test_Listen_TLSConfig_WithTLSConfigFunc
// Test_Listen_TLSConfig_WithTLSConfigFunc checks that TLSConfigFunc is NOT
// invoked when the caller supplies a complete TLSConfig.
func Test_Listen_TLSConfig_WithTLSConfigFunc(t *testing.T) {
	t.Parallel()
	cert, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
	require.NoError(t, err)
	var calledTLSConfigFunc bool
	app := New()
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		TLSConfig: &tls.Config{
			MinVersion:   tls.VersionTLS12,
			Certificates: []tls.Certificate{cert},
		},
		TLSConfigFunc: func(_ *tls.Config) {
			calledTLSConfigFunc = true
		},
	}))
	require.False(t, calledTLSConfigFunc)
}
// go test -run Test_Listen_AutoCert_Conflicts
// Test_Listen_AutoCert_Conflicts checks that Listen rejects an AutoCertManager
// combined with CertFile/CertKeyFile, returning ErrAutoCertWithCertFile.
func Test_Listen_AutoCert_Conflicts(t *testing.T) {
	t.Parallel()
	app := New()
	err := app.Listen(":0", ListenConfig{
		AutoCertManager: &autocert.Manager{},
		CertFile:        "./.github/testdata/ssl.pem",
		CertKeyFile:     "./.github/testdata/ssl.key",
	})
	require.ErrorIs(t, err, ErrAutoCertWithCertFile)
}
// go test -run Test_Listen_ListenerAddrFunc
// Test_Listen_ListenerAddrFunc checks that ListenerAddrFunc receives the bound
// address for a TLS listener, and that the address reports network "tcp".
func Test_Listen_ListenerAddrFunc(t *testing.T) {
	var network string
	app := New()
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		ListenerAddrFunc: func(addr net.Addr) {
			network = addr.Network()
		},
		CertFile:    "./.github/testdata/ssl.pem",
		CertKeyFile: "./.github/testdata/ssl.key",
	}))
	require.Equal(t, "tcp", network)
}
// go test -run Test_Listen_BeforeServeFunc
// Test_Listen_BeforeServeFunc checks that an error returned by BeforeServeFunc
// is propagated by Listen. It also checks that the hook sees the app (with
// zero handlers registered).
func Test_Listen_BeforeServeFunc(t *testing.T) {
	var handlers uint32
	app := New()
	// NOTE(review): Listen returns early via the hook error, so this Shutdown
	// runs against a server that never served.
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	wantErr := errors.New("test")
	require.ErrorIs(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		BeforeServeFunc: func(fiber *App) error {
			handlers = fiber.HandlersCount()
			return wantErr
		},
	}), wantErr)
	require.Zero(t, handlers)
}
// go test -run Test_Listen_ListenerNetwork
// Test_Listen_ListenerNetwork checks that ListenerNetwork selects the address
// family. NetworkTCP6 binds to "[::]:<port>" and NetworkTCP4 binds to
// "0.0.0.0:<port>", as reported through ListenerAddrFunc.
func Test_Listen_ListenerNetwork(t *testing.T) {
	var network string
	app := New()
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		ListenerNetwork:       NetworkTCP6,
		ListenerAddrFunc: func(addr net.Addr) {
			network = addr.String()
		},
	}))
	require.Contains(t, network, "[::]:")
	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0", ListenConfig{
		DisableStartupMessage: true,
		ListenerNetwork:       NetworkTCP4,
		ListenerAddrFunc: func(addr net.Addr) {
			network = addr.String()
		},
	}))
	require.Contains(t, network, "0.0.0.0:")
}
// go test -run Test_Listen_ListenerNetwork_Unix
// Test_Listen_ListenerNetwork_Unix checks listening on a unix domain socket:
//   - the socket file is created at the requested path;
//   - it gets UnixSocketFileMode permissions;
//   - ListenerAddrFunc receives the socket path;
//   - a request sent over the socket is served.
func Test_Listen_ListenerNetwork_Unix(t *testing.T) {
	app := New()
	app.Get("/test", func(c Ctx) error {
		return c.SendString("all good")
	})
	var (
		f       os.FileInfo
		statErr error
		network string
		reqErr  error
		resp    = &fasthttp.Response{}
	)
	// Create temporary directory for storing socket in
	tmp, err := os.MkdirTemp(os.TempDir(), "fiber-test")
	require.NoError(t, err)
	sock := filepath.Join(tmp, "fiber-test.sock")
	// Make sure temporary directory is cleaned up
	defer func() { assert.NoError(t, os.RemoveAll(tmp)) }()
	// Send request through socket. done is closed once reqErr/resp have been
	// written, so the test goroutine can read them without a data race.
	done := make(chan struct{})
	go func() {
		defer close(done)
		time.Sleep(1000 * time.Millisecond)
		client := &fasthttp.HostClient{
			Addr: sock,
			Dial: func(addr string) (net.Conn, error) {
				return net.Dial("unix", addr)
			},
		}
		req := &fasthttp.Request{}
		req.SetRequestURI("http://host/test")
		reqErr = client.Do(req, resp)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(sock, ListenConfig{
		DisableStartupMessage: true,
		ListenerNetwork:       NetworkUnix,
		UnixSocketFileMode:    0o666,
		ListenerAddrFunc: func(addr net.Addr) {
			network = addr.String()
			f, statErr = os.Stat(network)
		},
	}))
	<-done
	// Verify that listening and setting permissions works correctly
	require.Equal(t, sock, network)
	require.NoError(t, statErr)
	require.Equal(t, os.FileMode(0o666), f.Mode().Perm())
	// Verify that request was successful
	require.NoError(t, reqErr)
	require.Equal(t, 200, resp.StatusCode())
	require.Equal(t, "all good", string(resp.Body()))
}
// go test -run Test_Listen_Master_Process_Show_Startup_Message
// Test_Listen_Master_Process_Show_Startup_Message renders the prefork master's
// banner for an empty host with TLS. It checks the https 127.0.0.1 URL, the
// "bound on host 0.0.0.0" note, the Child PIDs listing and the Prefork entry.
func Test_Listen_Master_Process_Show_Startup_Message(t *testing.T) {
	cfg := ListenConfig{
		EnablePrefork: true,
	}
	// Grab a free port, then release it; only the number is used below.
	ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	addr, ok := ln.Addr().(*net.TCPAddr)
	require.True(t, ok)
	port := addr.Port
	require.NoError(t, ln.Close())
	// 60 PIDs (10 copies of a 6-PID template), more than one printed row.
	childTemplate := []int{11111, 22222, 33333, 44444, 55555, 60000}
	childPIDs := make([]int, 0, len(childTemplate)*10)
	for range 10 {
		childPIDs = append(childPIDs, childTemplate...)
	}
	app := New()
	listenData := app.prepareListenData(fmt.Sprintf(":%d", port), true, &cfg, childPIDs)
	startupMessage := captureOutput(func() {
		app.startupMessage(listenData, &cfg)
	})
	// Zero Colors: the assertions expect no escape codes. Presumably they are
	// stripped because captured stdout is a pipe, not a terminal.
	colors := Colors{}
	require.Contains(t, startupMessage, fmt.Sprintf("https://127.0.0.1:%d", port))
	require.Contains(t, startupMessage, fmt.Sprintf("(bound on host 0.0.0.0 and port %d)", port))
	require.Contains(t, startupMessage, "Child PIDs")
	require.Contains(t, startupMessage, "11111, 22222, 33333, 44444, 55555, 60000")
	require.Contains(t, startupMessage, fmt.Sprintf("Prefork: \t\t\t%sEnabled%s", colors.Blue, colors.Reset))
}
// go test -run Test_Listen_Master_Process_Show_Startup_MessageWithAppName
// Test_Listen_Master_Process_Show_Startup_MessageWithAppName checks that the
// configured AppName appears in the prefork master's startup banner.
func Test_Listen_Master_Process_Show_Startup_MessageWithAppName(t *testing.T) {
	cfg := ListenConfig{
		EnablePrefork: true,
	}
	app := New(Config{AppName: "Test App v3.0.0"})
	// Grab a free port, then release it; only the number is used below.
	ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	addr, ok := ln.Addr().(*net.TCPAddr)
	require.True(t, ok)
	port := addr.Port
	require.NoError(t, ln.Close())
	childTemplate := []int{11111, 22222, 33333, 44444, 55555, 60000}
	childPIDs := make([]int, 0, len(childTemplate)*10)
	for range 10 {
		childPIDs = append(childPIDs, childTemplate...)
	}
	listenData := app.prepareListenData(fmt.Sprintf(":%d", port), true, &cfg, childPIDs)
	startupMessage := captureOutput(func() {
		app.startupMessage(listenData, &cfg)
	})
	require.Equal(t, "Test App v3.0.0", app.Config().AppName)
	require.Contains(t, startupMessage, app.Config().AppName)
}
// go test -run Test_Listen_Master_Process_Show_Startup_MessageWithAppNameNonAscii
// Test_Listen_Master_Process_Show_Startup_MessageWithAppNameNonAscii checks
// that a non-ASCII (accented) AppName is printed unchanged in the banner.
func Test_Listen_Master_Process_Show_Startup_MessageWithAppNameNonAscii(t *testing.T) {
	cfg := ListenConfig{
		EnablePrefork: true,
	}
	appName := "Serveur de vérification des données"
	app := New(Config{AppName: appName})
	// Grab a free port, then release it; only the number is used below.
	ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	addr, ok := ln.Addr().(*net.TCPAddr)
	require.True(t, ok)
	port := addr.Port
	require.NoError(t, ln.Close())
	listenData := app.prepareListenData(fmt.Sprintf(":%d", port), false, &cfg, nil)
	startupMessage := captureOutput(func() {
		app.startupMessage(listenData, &cfg)
	})
	require.Contains(t, startupMessage, "Serveur de vérification des données")
}
// go test -run Test_Listen_Master_Process_Show_Startup_MessageWithDisabledPreforkAndCustomEndpoint
// Test_Listen_Master_Process_Show_Startup_MessageWithDisabledPreforkAndCustomEndpoint
// checks the banner for an explicit host ("server.com") with TLS and prefork
// disabled. The URL uses that host, and the Prefork entry reads "Disabled".
func Test_Listen_Master_Process_Show_Startup_MessageWithDisabledPreforkAndCustomEndpoint(t *testing.T) {
	cfg := ListenConfig{
		EnablePrefork: false,
	}
	appName := "Fiber Example Application"
	app := New(Config{AppName: appName})
	// Grab a free port, then release it; only the number is used below.
	ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	addr, ok := ln.Addr().(*net.TCPAddr)
	require.True(t, ok)
	port := addr.Port
	require.NoError(t, ln.Close())
	listenData := app.prepareListenData(fmt.Sprintf("server.com:%d", port), true, &cfg, nil)
	startupMessage := captureOutput(func() {
		app.startupMessage(listenData, &cfg)
	})
	colors := Colors{}
	require.Contains(t, startupMessage, fmt.Sprintf("%sINFO%s", colors.Green, colors.Reset))
	require.Contains(t, startupMessage, fmt.Sprintf("%s%s%s", colors.Blue, appName, colors.Reset))
	expectedURL := fmt.Sprintf("https://server.com:%d", port)
	require.Contains(t, startupMessage, fmt.Sprintf("%s%s%s", colors.Blue, expectedURL, colors.Reset))
	require.Contains(t, startupMessage, fmt.Sprintf("Prefork: \t\t\t%sDisabled%s", colors.Red, colors.Reset))
}
// Test_StartupMessageCustomization checks that an OnPreStartupMessage hook can
// replace the banner header and the entries (ResetEntries plus AddInfo). The
// default entries must not be printed, and the OnPostStartupMessage hook must
// see Disabled, IsChild and Prevented all false.
func Test_StartupMessageCustomization(t *testing.T) {
	cfg := ListenConfig{}
	app := New()
	listenData := app.prepareListenData(":8080", false, &cfg, nil)
	app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
		data.BannerHeader = "FOOBER v98\n-------"
		data.ResetEntries()
		data.AddInfo("git_hash", "Git hash", "abc123", 3)
		data.AddInfo("version", "Version", "v98", 2)
		return nil
	})
	var post PostStartupMessageData
	app.Hooks().OnPostStartupMessage(func(data *PostStartupMessageData) error {
		post = *data
		return nil
	})
	startupMessage := captureOutput(func() {
		app.startupMessage(listenData, &cfg)
	})
	require.Contains(t, startupMessage, "FOOBER v98")
	require.Contains(t, startupMessage, "Git hash: \tabc123")
	require.Contains(t, startupMessage, "Version: \tv98")
	require.NotContains(t, startupMessage, "Server started on:")
	require.NotContains(t, startupMessage, "Prefork:")
	require.False(t, post.Disabled)
	require.False(t, post.IsChild)
	require.False(t, post.Prevented)
}
// Test_StartupMessageDisabledPostHook checks that DisableStartupMessage
// suppresses all output while the OnPostStartupMessage hook still runs and
// reports Disabled=true.
func Test_StartupMessageDisabledPostHook(t *testing.T) {
	cfg := ListenConfig{DisableStartupMessage: true}
	app := New()
	listenData := app.prepareListenData(":7070", false, &cfg, nil)
	var post PostStartupMessageData
	app.Hooks().OnPostStartupMessage(func(data *PostStartupMessageData) error {
		post = *data
		return nil
	})
	startupMessage := captureOutput(func() {
		app.startupMessage(listenData, &cfg)
	})
	require.Empty(t, startupMessage)
	require.True(t, post.Disabled)
	require.False(t, post.IsChild)
	require.False(t, post.Prevented)
}
// Test_StartupMessagePreventedByHook checks that setting PreventDefault in an
// OnPreStartupMessage hook suppresses all output. The OnPostStartupMessage
// hook must still run and report Prevented=true.
func Test_StartupMessagePreventedByHook(t *testing.T) {
	cfg := ListenConfig{}
	app := New()
	listenData := app.prepareListenData(":9090", false, &cfg, nil)
	app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
		data.PreventDefault = true
		return nil
	})
	var post PostStartupMessageData
	app.Hooks().OnPostStartupMessage(func(data *PostStartupMessageData) error {
		post = *data
		return nil
	})
	startupMessage := captureOutput(func() {
		app.startupMessage(listenData, &cfg)
	})
	require.Empty(t, startupMessage)
	require.False(t, post.Disabled)
	require.False(t, post.IsChild)
	require.True(t, post.Prevented)
}
// go test -run Test_Listen_Print_Route
// Test_Listen_Print_Route checks that the routes table includes the method,
// path, handler name and route name of a single named GET route.
func Test_Listen_Print_Route(t *testing.T) {
	app := New()
	app.Get("/", emptyHandler).Name("routeName")

	output := captureOutput(app.printRoutesMessage)

	for _, want := range []string{MethodGet, "/", "emptyHandler", "routeName"} {
		require.Contains(t, output, want)
	}
}
// go test -run Test_Listen_Print_Route_With_Group
// Test_Listen_Print_Route_With_Group checks that routes registered through a
// group appear in the routes table with their prefixed paths and methods.
func Test_Listen_Print_Route_With_Group(t *testing.T) {
	app := New()
	app.Get("/", emptyHandler)

	v1 := app.Group("v1")
	v1.Get("/test", emptyHandler).Name("v1")
	v1.Post("/test/fiber", emptyHandler)
	v1.Put("/test/fiber/*", emptyHandler)

	output := captureOutput(app.printRoutesMessage)

	expected := []string{
		MethodGet, "/", "emptyHandler",
		"/v1/test",
		"POST", "/v1/test/fiber",
		"PUT", "/v1/test/fiber/*",
	}
	for _, want := range expected {
		require.Contains(t, output, want)
	}
}
// captureOutput runs f with os.Stdout, os.Stderr and the standard library
// logger redirected into a pipe, and returns everything written there.
//
// The original writers are restored before it returns, even if f panics (the
// panic is propagated). Both ends of the pipe are closed, so repeated calls do
// not leak file descriptors.
func captureOutput(f func()) string {
	reader, writer, err := os.Pipe()
	if err != nil {
		panic(err)
	}
	stdout := os.Stdout
	stderr := os.Stderr
	defer func() {
		os.Stdout = stdout
		os.Stderr = stderr
		log.SetOutput(os.Stderr)
	}()
	os.Stdout = writer
	os.Stderr = writer
	log.SetOutput(writer)
	// Buffered so the reader goroutine can always deliver its result and exit,
	// even when f panics and nobody receives from the channel.
	out := make(chan string, 1)
	go func() {
		var buf bytes.Buffer
		_, copyErr := io.Copy(&buf, reader)
		_ = reader.Close() //nolint:errcheck // best-effort release of the read end
		if copyErr != nil {
			panic(copyErr)
		}
		out <- buf.String() // this out channel helps in synchronization
	}()
	func() {
		// Close the write end even if f panics, so io.Copy above sees EOF.
		defer func() {
			if closeErr := writer.Close(); closeErr != nil {
				panic(closeErr)
			}
		}()
		f()
	}()
	return <-out
}
// emptyHandler is a no-op route handler used by the route-printing tests. It
// ignores the context and always returns nil.
func emptyHandler(_ Ctx) error {
	return nil
}
================================================
FILE: log/default.go
================================================
package log
import (
"context"
"fmt"
"io"
"log"
"os"
"github.com/gofiber/utils/v2"
"github.com/valyala/bytebufferpool"
)
// Compile-time check that *defaultLogger satisfies AllLogger[*log.Logger].
var _ AllLogger[*log.Logger] = (*defaultLogger)(nil)

// defaultLogger is a logger backed by a standard library *log.Logger.
type defaultLogger struct {
	// stdlog is the logger every emitted line is written through.
	stdlog *log.Logger
	// level is the minimum level emitted; messages below it are dropped.
	level Level
	// depth is the calldepth passed to stdlog.Output (affects caller file:line reporting).
	depth int
}
// privateLog logs a message at the given level using the default logger.
// The message is the level tag followed by fmtArgs formatted with fmt.Fprint.
// Messages below the configured level are dropped. LevelPanic panics with the
// message, and LevelFatal exits the process with status 1, both after the
// line has been written.
func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) {
	if l.level > lv {
		return
	}
	level := lv.toString()
	buf := bytebufferpool.Get()
	buf.WriteString(level)
	fmt.Fprint(buf, fmtArgs...)
	// Copy the message out and return the buffer to the pool before any
	// panic/exit; previously the panic path never returned it.
	msg := buf.String()
	buf.Reset()
	bytebufferpool.Put(buf)
	// Output is called directly from this frame so l.depth still points at the caller.
	_ = l.stdlog.Output(l.depth, msg) //nolint:errcheck // It is fine to ignore the error
	if lv == LevelPanic {
		panic(msg)
	}
	if lv == LevelFatal {
		os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called
	}
}
// privateLogf logs a formatted message at the given level using the default logger.
// When fmtArgs is empty, format is written verbatim with fmt.Fprint, so a
// message containing '%' is not mangled into %!verb noise. Messages below the
// configured level are dropped. LevelPanic panics with the message, and
// LevelFatal exits the process with status 1, both after the line is written.
func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) {
	if l.level > lv {
		return
	}
	level := lv.toString()
	buf := bytebufferpool.Get()
	buf.WriteString(level)
	if len(fmtArgs) > 0 {
		_, _ = fmt.Fprintf(buf, format, fmtArgs...)
	} else {
		_, _ = fmt.Fprint(buf, format)
	}
	// Copy the message out and return the buffer to the pool before any
	// panic/exit; previously the panic path never returned it.
	msg := buf.String()
	buf.Reset()
	bytebufferpool.Put(buf)
	// Output is called directly from this frame so l.depth still points at the caller.
	_ = l.stdlog.Output(l.depth, msg) //nolint:errcheck // It is fine to ignore the error
	if lv == LevelPanic {
		panic(msg)
	}
	if lv == LevelFatal {
		os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called
	}
}
// privateLogw writes the level tag, then format verbatim (it is a message, not
// a format string), then space-separated key=value pairs from keysAndValues.
//
// String keys are written as-is and other keys via fmt.Fprint; values are
// rendered with utils.ToString. A trailing key without a value is paired with
// the placeholder "KEYVALS UNPAIRED". Messages whose level is below the
// configured level are dropped. After the line is written, LevelPanic panics
// with it and LevelFatal exits with status 1.
func (l *defaultLogger) privateLogw(lv Level, format string, keysAndValues []any) {
	if l.level > lv {
		return
	}
	buf := bytebufferpool.Get()
	buf.WriteString(lv.toString())
	if format != "" {
		buf.WriteString(format)
	}
	for i := 0; i < len(keysAndValues); i += 2 {
		if i > 0 || format != "" {
			buf.WriteByte(' ')
		}
		switch key := keysAndValues[i].(type) {
		case string:
			buf.WriteString(key)
		default:
			_, _ = fmt.Fprint(buf, key)
		}
		buf.WriteByte('=')
		// Never append the placeholder to keysAndValues: it is the caller's
		// variadic slice and may share a backing array with spare capacity,
		// so appending would overwrite the caller's data.
		if i+1 < len(keysAndValues) {
			buf.WriteString(utils.ToString(keysAndValues[i+1]))
		} else {
			buf.WriteString("KEYVALS UNPAIRED")
		}
	}
	msg := buf.String() // copies, so the buffer can be released right away
	buf.Reset()
	bytebufferpool.Put(buf)
	// Output must be called directly from here so l.depth stays correct.
	_ = l.stdlog.Output(l.depth, msg) //nolint:errcheck // It is fine to ignore the error
	switch lv {
	case LevelPanic:
		panic(msg)
	case LevelFatal:
		os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called
	}
}
// Trace writes args, formatted like fmt.Print, at LevelTrace.
func (l *defaultLogger) Trace(args ...any) {
	l.privateLog(LevelTrace, args)
}

// Debug writes args, formatted like fmt.Print, at LevelDebug.
func (l *defaultLogger) Debug(args ...any) {
	l.privateLog(LevelDebug, args)
}

// Info writes args, formatted like fmt.Print, at LevelInfo.
func (l *defaultLogger) Info(args ...any) {
	l.privateLog(LevelInfo, args)
}

// Warn writes args, formatted like fmt.Print, at LevelWarn.
func (l *defaultLogger) Warn(args ...any) {
	l.privateLog(LevelWarn, args)
}

// Error writes args, formatted like fmt.Print, at LevelError.
func (l *defaultLogger) Error(args ...any) {
	l.privateLog(LevelError, args)
}

// Fatal writes args at LevelFatal and then exits the process with status 1.
func (l *defaultLogger) Fatal(args ...any) {
	l.privateLog(LevelFatal, args)
}

// Panic writes args at LevelPanic and then panics with the logged line.
func (l *defaultLogger) Panic(args ...any) {
	l.privateLog(LevelPanic, args)
}
// Tracef writes the message built from format and args at LevelTrace.
func (l *defaultLogger) Tracef(format string, args ...any) {
	l.privateLogf(LevelTrace, format, args)
}

// Debugf writes the message built from format and args at LevelDebug.
func (l *defaultLogger) Debugf(format string, args ...any) {
	l.privateLogf(LevelDebug, format, args)
}

// Infof writes the message built from format and args at LevelInfo.
func (l *defaultLogger) Infof(format string, args ...any) {
	l.privateLogf(LevelInfo, format, args)
}

// Warnf writes the message built from format and args at LevelWarn.
func (l *defaultLogger) Warnf(format string, args ...any) {
	l.privateLogf(LevelWarn, format, args)
}

// Errorf writes the message built from format and args at LevelError.
func (l *defaultLogger) Errorf(format string, args ...any) {
	l.privateLogf(LevelError, format, args)
}

// Fatalf writes the formatted message at LevelFatal, then exits with status 1.
func (l *defaultLogger) Fatalf(format string, args ...any) {
	l.privateLogf(LevelFatal, format, args)
}

// Panicf writes the formatted message at LevelPanic, then panics with it.
func (l *defaultLogger) Panicf(format string, args ...any) {
	l.privateLogf(LevelPanic, format, args)
}
// Tracew writes msg followed by key=value pairs from kv at LevelTrace.
func (l *defaultLogger) Tracew(msg string, kv ...any) {
	l.privateLogw(LevelTrace, msg, kv)
}

// Debugw writes msg followed by key=value pairs from kv at LevelDebug.
func (l *defaultLogger) Debugw(msg string, kv ...any) {
	l.privateLogw(LevelDebug, msg, kv)
}

// Infow writes msg followed by key=value pairs from kv at LevelInfo.
func (l *defaultLogger) Infow(msg string, kv ...any) {
	l.privateLogw(LevelInfo, msg, kv)
}

// Warnw writes msg followed by key=value pairs from kv at LevelWarn.
func (l *defaultLogger) Warnw(msg string, kv ...any) {
	l.privateLogw(LevelWarn, msg, kv)
}

// Errorw writes msg followed by key=value pairs from kv at LevelError.
func (l *defaultLogger) Errorw(msg string, kv ...any) {
	l.privateLogw(LevelError, msg, kv)
}

// Fatalw writes msg and the key=value pairs at LevelFatal, then exits with status 1.
func (l *defaultLogger) Fatalw(msg string, kv ...any) {
	l.privateLogw(LevelFatal, msg, kv)
}

// Panicw writes msg and the key=value pairs at LevelPanic, then panics with the line.
func (l *defaultLogger) Panicw(msg string, kv ...any) {
	l.privateLogw(LevelPanic, msg, kv)
}
// WithContext returns a copy of l that writes to the same *log.Logger at the
// same level. The context is ignored; only the call depth is reduced by one,
// because callers invoke the returned logger's methods directly instead of
// going through a package-level wrapper.
func (l *defaultLogger) WithContext(_ context.Context) CommonLogger {
	child := *l
	child.depth--
	return &child
}
// SetLevel sets the minimum level; messages with a lower level are discarded.
func (l *defaultLogger) SetLevel(lv Level) {
	l.level = lv
}

// SetOutput redirects the wrapped *log.Logger to w.
func (l *defaultLogger) SetOutput(w io.Writer) {
	l.stdlog.SetOutput(w)
}

// Logger exposes the wrapped *log.Logger so callers can adjust its prefix,
// flags or output directly.
func (l *defaultLogger) Logger() *log.Logger {
	return l.stdlog
}
// DefaultLogger returns the package-level logger as an AllLogger[T], or nil
// when the installed logger does not implement AllLogger[T].
func DefaultLogger[T any]() AllLogger[T] {
	typed, ok := logger.(AllLogger[T])
	if !ok {
		return nil
	}
	return typed
}
================================================
FILE: log/default_test.go
================================================
package log
import (
"bytes"
"context"
"log"
"os"
"testing"
"github.com/stretchr/testify/require"
)
const work = "work" // word interpolated into the formatted-logger test messages
func initDefaultLogger() { // installs a fresh package logger: stderr, no flags, zero level (LevelTrace)
	logger = &defaultLogger{
		stdlog: log.New(os.Stderr, "", 0),
		depth:  4, // calldepth handed to log.Logger.Output
	}
}
type byteSliceWriter struct { // in-memory io.Writer used to capture log output
	b []byte // every byte written so far, in order
}
func (w *byteSliceWriter) Write(p []byte) (int, error) { // appends p; never returns an error
	w.b = append(w.b, p...)
	return len(p), nil
}
// Test_WithContextCaller checks that both the package-level functions and a
// logger obtained via WithContext report the caller's file:line (log.Lshortfile),
// i.e. that their call depths are correct.
func Test_WithContextCaller(t *testing.T) {
	logger = &defaultLogger{
		stdlog: log.New(os.Stderr, "", log.Lshortfile),
		depth:  4,
	}
	var w byteSliceWriter
	SetOutput(&w)
	ctx := context.TODO()

	// Build the expected output with a reference logger that reports the line
	// of whoever calls mark. mark() is evaluated as an argument on the same line
	// as each log call under test, so both must report the same file:line. This
	// avoids hard-coding line numbers that shift whenever this file is edited.
	var want bytes.Buffer
	ref := log.New(&want, "", log.Lshortfile)
	mark := func() string {
		_ = ref.Output(2, "[Info] ") //nolint:errcheck // writes to an in-memory buffer
		return ""
	}

	WithContext(ctx).Info(mark())
	Info(mark())

	require.Contains(t, want.String(), "default_test.go:")
	require.Equal(t, want.String(), string(w.b))
}
// Test_DefaultLogger exercises every fmt.Print-style package function.
func Test_DefaultLogger(t *testing.T) {
	initDefaultLogger()
	w := &byteSliceWriter{}
	SetOutput(w)

	Trace("trace work")
	Debug("received work order")
	Info("starting work")
	Warn("work may fail")
	Error("work failed")
	require.Panics(t, func() { Panic("work panic") })

	want := "[Trace] trace work\n" +
		"[Debug] received work order\n" +
		"[Info] starting work\n" +
		"[Warn] work may fail\n" +
		"[Error] work failed\n" +
		"[Panic] work panic\n"
	require.Equal(t, want, string(w.b))
}
// Test_DefaultFormatLogger exercises every Printf-style package function.
func Test_DefaultFormatLogger(t *testing.T) {
	initDefaultLogger()
	w := &byteSliceWriter{}
	SetOutput(w)

	Tracef("trace %s", work)
	Debugf("received %s order", work)
	Infof("starting %s", work)
	Warnf("%s may fail", work)
	Errorf("%s failed", work)
	require.Panics(t, func() { Panicf("%s panic", work) })

	want := "[Trace] trace work\n" +
		"[Debug] received work order\n" +
		"[Info] starting work\n" +
		"[Warn] work may fail\n" +
		"[Error] work failed\n" +
		"[Panic] work panic\n"
	require.Equal(t, want, string(w.b))
}
// Test_CtxLogger exercises the Printf-style methods of loggers returned by WithContext.
func Test_CtxLogger(t *testing.T) {
	initDefaultLogger()
	w := &byteSliceWriter{}
	SetOutput(w)
	bg := context.Background()

	WithContext(bg).Tracef("trace %s", work)
	WithContext(bg).Debugf("received %s order", work)
	WithContext(bg).Infof("starting %s", work)
	WithContext(bg).Warnf("%s may fail", work)
	WithContext(bg).Errorf("%s failed %d", work, 50)
	require.Panics(t, func() { WithContext(bg).Panicf("%s panic", work) })

	want := "[Trace] trace work\n" +
		"[Debug] received work order\n" +
		"[Info] starting work\n" +
		"[Warn] work may fail\n" +
		"[Error] work failed 50\n" +
		"[Panic] work panic\n"
	require.Equal(t, want, string(w.b))
}
// Test_LogfKeyAndValues checks the key=value rendering of privateLogw,
// including the placeholder used for an unpaired trailing key.
// (The previous table carried an fmtArgs field that was never read.)
func Test_LogfKeyAndValues(t *testing.T) {
	tests := []struct {
		name          string
		format        string
		wantOutput    string
		keysAndValues []any
		level         Level
	}{
		{
			name:          "test logf with debug level and key-values",
			level:         LevelDebug,
			format:        "",
			keysAndValues: []any{"name", "Bob", "age", 30},
			wantOutput:    "[Debug] name=Bob age=30\n",
		},
		{
			name:          "test logf with info level and key-values",
			level:         LevelInfo,
			format:        "",
			keysAndValues: []any{"status", "ok", "code", 200},
			wantOutput:    "[Info] status=ok code=200\n",
		},
		{
			name:          "test logf with warn level and key-values",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error", "not found", "id", 123},
			wantOutput:    "[Warn] error=not found id=123\n",
		},
		{
			name:          "test logf with format and key-values",
			level:         LevelWarn,
			format:        "test",
			keysAndValues: []any{"error", "not found", "id", 123},
			wantOutput:    "[Warn] test error=not found id=123\n",
		},
		{
			name:          "test logf with one key",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error"},
			wantOutput:    "[Warn] error=KEYVALS UNPAIRED\n",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var buf bytes.Buffer
			l := &defaultLogger{
				stdlog: log.New(&buf, "", 0),
				level:  tt.level,
				depth:  4,
			}
			l.privateLogw(tt.level, tt.format, tt.keysAndValues)
			require.Equal(t, tt.wantOutput, buf.String())
		})
	}
}
// Test_SetLevel verifies that SetLevel stores every known level and that an
// out-of-range level is kept and rendered as "[?N] ".
func Test_SetLevel(t *testing.T) {
	setLogger := &defaultLogger{
		stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
		depth:  4,
	}

	known := []Level{LevelTrace, LevelDebug, LevelInfo, LevelWarn, LevelError, LevelFatal, LevelPanic}
	for _, lv := range known {
		setLogger.SetLevel(lv)
		require.Equal(t, lv, setLogger.level)
		require.Equal(t, lv.toString(), setLogger.level.toString())
	}

	setLogger.SetLevel(8)
	require.Equal(t, 8, int(setLogger.level))
	require.Equal(t, "[?8] ", setLogger.level.toString())
}
// Test_Logger checks that Logger returns the wrapped *log.Logger itself, so
// configuration changes made through it affect the defaultLogger.
func Test_Logger(t *testing.T) {
	underlyingLogger := log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds)
	setLogger := &defaultLogger{
		stdlog: underlyingLogger,
		depth:  4,
	}

	require.Same(t, underlyingLogger, setLogger.Logger())

	// Named stdLogger to avoid shadowing the package-level logger variable.
	stdLogger := setLogger.Logger()
	// Use flags that differ from the initial ones; setting the same flags again
	// would not prove that the change reaches the wrapped logger.
	stdLogger.SetFlags(log.Lshortfile)
	require.Equal(t, log.Lshortfile, setLogger.stdlog.Flags())
}
// captureLogwOutput resets the package logger, points it at a fresh in-memory
// writer, runs emit and returns everything that was written.
func captureLogwOutput(emit func()) string {
	initDefaultLogger()
	var w byteSliceWriter
	SetOutput(&w)
	emit()
	return string(w.b)
}

func Test_Debugw(t *testing.T) {
	pairs := []any{"key1", "value1", "key2", "value2"}
	got := captureLogwOutput(func() { Debugw("debug work", pairs...) })
	require.Equal(t, "[Debug] debug work key1=value1 key2=value2\n", got)
}

func Test_Infow(t *testing.T) {
	pairs := []any{"key1", "value1", "key2", "value2"}
	got := captureLogwOutput(func() { Infow("info work", pairs...) })
	require.Equal(t, "[Info] info work key1=value1 key2=value2\n", got)
}

func Test_Warnw(t *testing.T) {
	pairs := []any{"key1", "value1", "key2", "value2"}
	got := captureLogwOutput(func() { Warnw("warning work", pairs...) })
	require.Equal(t, "[Warn] warning work key1=value1 key2=value2\n", got)
}

func Test_Errorw(t *testing.T) {
	pairs := []any{"key1", "value1", "key2", "value2"}
	got := captureLogwOutput(func() { Errorw("error work", pairs...) })
	require.Equal(t, "[Error] error work key1=value1 key2=value2\n", got)
}

func Test_Panicw(t *testing.T) {
	pairs := []any{"key1", "value1", "key2", "value2"}
	got := captureLogwOutput(func() {
		require.Panics(t, func() { Panicw("panic work", pairs...) })
	})
	require.Equal(t, "[Panic] panic work key1=value1 key2=value2\n", got)
}

func Test_Tracew(t *testing.T) {
	pairs := []any{"key1", "value1", "key2", "value2"}
	got := captureLogwOutput(func() { Tracew("trace work", pairs...) })
	require.Equal(t, "[Trace] trace work key1=value1 key2=value2\n", got)
}
// stringKey is a non-string key type implementing fmt.Stringer, used to check
// how the logger renders keys that are not plain strings.
type stringKey struct{ value string }

// String renders the key with a "key:" prefix so it is recognizable in output.
func (k stringKey) String() string { return "key:" + k.value }
// Test_DefaultLoggerNonStringKeys checks that non-string keys (ints and
// fmt.Stringer values) are rendered instead of causing a panic.
func Test_DefaultLoggerNonStringKeys(t *testing.T) {
	t.Parallel()
	cases := []struct {
		name string
		emit func(l *defaultLogger)
		want string
	}{
		{
			name: "Tracew with non-string keys",
			emit: func(l *defaultLogger) { l.Tracew("trace", 123, "value", stringKey{value: "alpha"}, 42) },
			want: "[Trace] trace 123=value key:alpha=42\n",
		},
		{
			name: "Infow with non-string keys",
			emit: func(l *defaultLogger) { l.Infow("info", 456, "value", stringKey{value: "beta"}, 7) },
			want: "[Info] info 456=value key:beta=7\n",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			var buf bytes.Buffer
			l := &defaultLogger{stdlog: log.New(&buf, "", 0), level: LevelTrace, depth: 4}
			require.NotPanics(t, func() { tc.emit(l) })
			require.Equal(t, tc.want, buf.String())
		})
	}
}
// Benchmark_LogfKeyAndValues measures privateLogw for several key/value shapes.
func Benchmark_LogfKeyAndValues(b *testing.B) {
	tests := []struct {
		name          string
		format        string
		keysAndValues []any
		level         Level
	}{
		{
			name:          "test logf with debug level and key-values",
			level:         LevelDebug,
			format:        "",
			keysAndValues: []any{"name", "Bob", "age", 30},
		},
		{
			name:          "test logf with info level and key-values",
			level:         LevelInfo,
			format:        "",
			keysAndValues: []any{"status", "ok", "code", 200},
		},
		{
			name:          "test logf with warn level and key-values",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error", "not found", "id", 123},
		},
		{
			name:          "test logf with format and key-values",
			level:         LevelWarn,
			format:        "test",
			keysAndValues: []any{"error", "not found", "id", 123},
		},
		{
			name:          "test logf with one key",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error"},
		},
	}
	for _, tt := range tests {
		b.Run(tt.name, func(bb *testing.B) {
			var buf bytes.Buffer
			l := &defaultLogger{
				stdlog: log.New(&buf, "", 0),
				level:  tt.level,
				depth:  4,
			}
			bb.ReportAllocs()
			for bb.Loop() {
				l.privateLogw(tt.level, tt.format, tt.keysAndValues)
				// Drop the output each iteration; otherwise the buffer grows for the
				// whole run and its reallocations skew the time and alloc figures.
				buf.Reset()
			}
		})
	}
}
// Benchmark_LogfKeyAndValues_Parallel measures privateLogw with one logger and
// buffer per goroutine.
func Benchmark_LogfKeyAndValues_Parallel(b *testing.B) {
	tests := []struct {
		name          string
		format        string
		keysAndValues []any
		level         Level
	}{
		{
			name:          "debug level with key-values",
			level:         LevelDebug,
			format:        "",
			keysAndValues: []any{"name", "Bob", "age", 30},
		},
		{
			name:          "info level with key-values",
			level:         LevelInfo,
			format:        "",
			keysAndValues: []any{"status", "ok", "code", 200},
		},
		{
			name:          "warn level with key-values",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error", "not found", "id", 123},
		},
		{
			name:          "warn level with format and key-values",
			level:         LevelWarn,
			format:        "test",
			keysAndValues: []any{"error", "not found", "id", 123},
		},
		{
			name:          "warn level with one key",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error"},
		},
	}
	for _, tt := range tests {
		b.Run(tt.name, func(bb *testing.B) {
			bb.ReportAllocs()
			bb.ResetTimer()
			bb.RunParallel(func(pb *testing.PB) {
				var buf bytes.Buffer
				l := &defaultLogger{
					stdlog: log.New(&buf, "", 0),
					level:  tt.level,
					depth:  4,
				}
				for pb.Next() {
					l.privateLogw(tt.level, tt.format, tt.keysAndValues)
					// Keep the per-goroutine buffer from growing for the whole run.
					buf.Reset()
				}
			})
		})
	}
}
================================================
FILE: log/fiberlog.go
================================================
package log
import (
"context"
"io"
)
// Fatal forwards args to the package logger's Fatal method. The built-in
// logger terminates the process with os.Exit(1) after writing the message.
func Fatal(args ...any) {
	logger.Fatal(args...)
}

// Error forwards args to the package logger's Error method.
func Error(args ...any) {
	logger.Error(args...)
}

// Warn forwards args to the package logger's Warn method.
func Warn(args ...any) {
	logger.Warn(args...)
}

// Info forwards args to the package logger's Info method.
func Info(args ...any) {
	logger.Info(args...)
}

// Debug forwards args to the package logger's Debug method.
func Debug(args ...any) {
	logger.Debug(args...)
}

// Trace forwards args to the package logger's Trace method.
func Trace(args ...any) {
	logger.Trace(args...)
}

// Panic forwards args to the package logger's Panic method. The built-in
// logger panics with the logged line after writing it.
func Panic(args ...any) {
	logger.Panic(args...)
}
// Fatalf calls the package logger's Fatalf method. The built-in logger exits
// the process with os.Exit(1) after writing the message; this wrapper itself
// does not call os.Exit.
func Fatalf(format string, v ...any) {
	logger.Fatalf(format, v...)
}

// Errorf calls the package logger's Errorf method.
func Errorf(format string, v ...any) {
	logger.Errorf(format, v...)
}

// Warnf calls the package logger's Warnf method.
func Warnf(format string, v ...any) {
	logger.Warnf(format, v...)
}

// Infof calls the package logger's Infof method.
func Infof(format string, v ...any) {
	logger.Infof(format, v...)
}

// Debugf calls the package logger's Debugf method.
func Debugf(format string, v ...any) {
	logger.Debugf(format, v...)
}

// Tracef calls the package logger's Tracef method.
func Tracef(format string, v ...any) {
	logger.Tracef(format, v...)
}

// Panicf calls the package logger's Panicf method. The built-in logger panics
// with the logged line after writing it.
func Panicf(format string, v ...any) {
	logger.Panicf(format, v...)
}
// Tracew logs msg at trace level with additional context, given as
// alternating key/value pairs (key1, value1, key2, value2, ...).
func Tracew(msg string, keysAndValues ...any) {
	logger.Tracew(msg, keysAndValues...)
}

// Debugw logs msg at debug level with additional context, given as
// alternating key/value pairs (key1, value1, key2, value2, ...).
func Debugw(msg string, keysAndValues ...any) {
	logger.Debugw(msg, keysAndValues...)
}

// Infow logs msg at info level with additional context, given as
// alternating key/value pairs (key1, value1, key2, value2, ...).
func Infow(msg string, keysAndValues ...any) {
	logger.Infow(msg, keysAndValues...)
}

// Warnw logs msg at warn level with additional context, given as
// alternating key/value pairs (key1, value1, key2, value2, ...).
func Warnw(msg string, keysAndValues ...any) {
	logger.Warnw(msg, keysAndValues...)
}

// Errorw logs msg at error level with additional context, given as
// alternating key/value pairs (key1, value1, key2, value2, ...).
func Errorw(msg string, keysAndValues ...any) {
	logger.Errorw(msg, keysAndValues...)
}

// Fatalw logs msg at fatal level with additional context, given as
// alternating key/value pairs. The built-in logger then exits with status 1.
func Fatalw(msg string, keysAndValues ...any) {
	logger.Fatalw(msg, keysAndValues...)
}

// Panicw logs msg at panic level with additional context, given as
// alternating key/value pairs. The built-in logger then panics with the line.
func Panicw(msg string, keysAndValues ...any) {
	logger.Panicw(msg, keysAndValues...)
}
// WithContext returns the CommonLogger derived from the package logger for ctx.
// The built-in logger ignores ctx; it only reduces its call depth by one so
// that file:line information still points at the caller.
func WithContext(ctx context.Context) CommonLogger {
	return logger.WithContext(ctx)
}

// SetLogger replaces the logger used by the package-level functions.
// Note that this method is not concurrent-safe and must not be called
// after the use of DefaultLogger and global functions from this package.
func SetLogger[T any](v AllLogger[T]) {
	logger = v
}

// SetOutput sets the output of the package logger. By default, it is os.Stderr.
func SetOutput(w io.Writer) {
	logger.SetOutput(w)
}

// SetLevel sets the minimum level of the package logger; messages with a lower
// level are not output. The default level is LevelTrace.
// Note that this method is not concurrent-safe.
func SetLevel(lv Level) {
	logger.SetLevel(lv)
}
================================================
FILE: log/fiberlog_test.go
================================================
package log
import (
"log"
"os"
"testing"
"github.com/stretchr/testify/require"
)
func Test_DefaultSystemLogger(t *testing.T) {
t.Parallel()
defaultL := DefaultLogger[*log.Logger]()
require.Equal(t, logger, defaultL)
}
func Test_SetLogger(t *testing.T) {
setLog := &defaultLogger{
stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
depth: 6,
}
SetLogger(setLog)
require.Equal(t, logger, setLog)
}
func Test_Fiberlog_SetLevel(t *testing.T) {
mockLogger := &defaultLogger{}
SetLogger(mockLogger)
// Test cases
testCases := []struct {
name string
level Level
expected Level
}{
{
name: "Test case 1",
level: LevelDebug,
expected: LevelDebug,
},
{
name: "Test case 2",
level: LevelInfo,
expected: LevelInfo,
},
{
name: "Test case 3",
level: LevelWarn,
expected: LevelWarn,
},
{
name: "Test case 4",
level: LevelError,
expected: LevelError,
},
{
name: "Test case 5",
level: LevelFatal,
expected: LevelFatal,
},
}
// Run tests
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
SetLevel(tc.level)
require.Equal(t, tc.expected, mockLogger.level)
})
}
}
func Benchmark_DefaultSystemLogger(b *testing.B) {
b.ReportAllocs()
for b.Loop() {
_ = DefaultLogger[*log.Logger]()
}
}
func Benchmark_SetLogger(b *testing.B) {
setLog := &defaultLogger{
stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
depth: 6,
}
b.ReportAllocs()
for b.Loop() {
SetLogger(setLog)
}
}
func Benchmark_Fiberlog_SetLevel(b *testing.B) {
mockLogger := &defaultLogger{}
SetLogger(mockLogger)
// Test cases
testCases := []struct {
name string
level Level
expected Level
}{
{
name: "Test case 1",
level: LevelDebug,
expected: LevelDebug,
},
{
name: "Test case 2",
level: LevelInfo,
expected: LevelInfo,
},
{
name: "Test case 3",
level: LevelWarn,
expected: LevelWarn,
},
{
name: "Test case 4",
level: LevelError,
expected: LevelError,
},
{
name: "Test case 5",
level: LevelFatal,
expected: LevelFatal,
},
}
for _, tc := range testCases {
b.ReportAllocs()
b.Run(tc.name, func(b *testing.B) {
for b.Loop() {
SetLevel(tc.level)
}
})
}
}
func Benchmark_DefaultSystemLogger_Parallel(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = DefaultLogger[*log.Logger]()
}
})
}
func Benchmark_SetLogger_Parallel(b *testing.B) {
setLog := &defaultLogger{
stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
depth: 6,
}
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
SetLogger(setLog)
}
})
}
func Benchmark_Fiberlog_SetLevel_Parallel(b *testing.B) {
mockLogger := &defaultLogger{}
SetLogger(mockLogger)
// Test cases
testCases := []struct {
name string
level Level
expected Level
}{
{
name: "Test case 1",
level: LevelDebug,
expected: LevelDebug,
},
{
name: "Test case 2",
level: LevelInfo,
expected: LevelInfo,
},
{
name: "Test case 3",
level: LevelWarn,
expected: LevelWarn,
},
{
name: "Test case 4",
level: LevelError,
expected: LevelError,
},
{
name: "Test case 5",
level: LevelFatal,
expected: LevelFatal,
},
}
for _, tc := range testCases {
b.Run(tc.name+"_Parallel", func(bb *testing.B) {
bb.ReportAllocs()
bb.ResetTimer()
bb.RunParallel(func(pb *testing.PB) {
for pb.Next() {
SetLevel(tc.level)
}
})
})
}
}
================================================
FILE: log/log.go
================================================
package log
import (
"context"
"fmt"
"io"
"log"
"os"
)
// baseLogger is the non-generic subset of AllLogger that the package-level
// functions need. Because it does not mention the type parameter of
// ConfigurableLogger, a logger of any AllLogger[T] instantiation can be stored
// in the package-level variable.
type baseLogger interface {
	CommonLogger
	WithContext(ctx context.Context) CommonLogger
	SetOutput(io.Writer)
	SetLevel(Level)
}

// logger backs the package-level functions. By default it writes to os.Stderr
// with date, microsecond time and short file name, at the zero level
// (LevelTrace), with call depth 4 for Output.
var logger baseLogger = &defaultLogger{
	depth:  4,
	stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lmicroseconds|log.Lshortfile),
}
// Logger provides one logging method per level; each takes the values to log.
type Logger interface {
	Trace(v ...any)
	Debug(v ...any)
	Info(v ...any)
	Warn(v ...any)
	Error(v ...any)
	Fatal(v ...any)
	Panic(v ...any)
}

// FormatLogger provides one logging method per level that builds the message
// from a format string and its arguments.
type FormatLogger interface {
	Tracef(format string, v ...any)
	Debugf(format string, v ...any)
	Infof(format string, v ...any)
	Warnf(format string, v ...any)
	Errorf(format string, v ...any)
	Fatalf(format string, v ...any)
	Panicf(format string, v ...any)
}

// WithLogger provides one logging method per level that writes a message
// followed by alternating key/value pairs.
type WithLogger interface {
	Tracew(msg string, keysAndValues ...any)
	Debugw(msg string, keysAndValues ...any)
	Infow(msg string, keysAndValues ...any)
	Warnw(msg string, keysAndValues ...any)
	Errorw(msg string, keysAndValues ...any)
	Fatalw(msg string, keysAndValues ...any)
	Panicw(msg string, keysAndValues ...any)
}

// CommonLogger is the set of logging operations available across Fiber's
// logging implementations: Logger, FormatLogger and WithLogger.
type CommonLogger interface {
	Logger
	FormatLogger
	WithLogger
}

// ConfigurableLogger provides methods to configure a logger.
type ConfigurableLogger[T any] interface {
	// SetLevel sets logging level.
	//
	// Available levels: Trace, Debug, Info, Warn, Error, Fatal, Panic.
	SetLevel(level Level)

	// SetOutput sets the logger output.
	SetOutput(w io.Writer)

	// Logger returns the logger instance. It can be used to adjust the logger configurations in case of need.
	Logger() T
}

// AllLogger is the combination of CommonLogger (Logger, FormatLogger and
// WithLogger), ConfigurableLogger and WithContext.
// Custom extensions can be made through AllLogger.
type AllLogger[T any] interface {
	CommonLogger
	ConfigurableLogger[T]

	// WithContext returns a new logger with the given context.
	WithContext(ctx context.Context) CommonLogger
}
// Level defines the priority of a log message.
// When a logger is configured with a level, any log message with a lower
// log level (smaller by integer comparison) will not be output.
type Level int

// The levels of logs, from least to most severe.
const (
	LevelTrace Level = iota // 0
	LevelDebug              // 1
	LevelInfo               // 2
	LevelWarn               // 3
	LevelError              // 4
	LevelFatal              // 5
	LevelPanic              // 6
)

// strs holds the printed prefix of each known level, indexed by its value.
var strs = []string{
	"[Trace] ",
	"[Debug] ",
	"[Info] ",
	"[Warn] ",
	"[Error] ",
	"[Fatal] ",
	"[Panic] ",
}

// toString returns the bracketed prefix written before a message, such as
// "[Info] ". Values outside LevelTrace..LevelPanic are rendered as "[?N] ".
func (lv Level) toString() string {
	if lv < LevelTrace || lv > LevelPanic {
		return fmt.Sprintf("[?%d] ", lv)
	}
	return strs[lv]
}
================================================
FILE: middleware/adaptor/adaptor.go
================================================
package adaptor
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"sync"
"unsafe"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttpadaptor"
)
// disableLogger satisfies the fasthttp Logger interface by dropping every message.
type disableLogger struct{}

// Printf ignores its format and arguments, so nothing is ever logged.
func (*disableLogger) Printf(string, ...any) {}
// ctxPool recycles fasthttp.RequestCtx values; New hands out a zero-valued one.
var ctxPool = sync.Pool{
	New: func() any { return &fasthttp.RequestCtx{} },
}
// localContextKeyType is an unexported key type. Values of distinct types never
// compare equal, so no other package can produce a colliding key.
type localContextKeyType struct{}

// localContextKey is the key under which HTTPHandlerWithContext stores Fiber's
// user context (c.Context()) in the fasthttp request context. Adapted
// http.Handler functions retrieve it with LocalContextFromHTTPRequest(r).
//
// A dedicated type is used instead of &struct{}{}: pointers to distinct
// zero-size variables may compare equal, so such a key is not guaranteed to be
// unique.
var localContextKey = localContextKeyType{}
// bufferSize is the size in bytes (32 KiB) of each pooled buffer.
const bufferSize = 32 << 10

// bufferPool hands out *[bufferSize]byte scratch buffers.
var bufferPool = sync.Pool{
	New: func() any { return &[bufferSize]byte{} },
}
// ErrRemoteAddrEmpty is returned by resolveRemoteAddr when the remote address is empty.
var ErrRemoteAddrEmpty = errors.New("remote address cannot be empty")

// ErrRemoteAddrTooLong is returned by resolveRemoteAddr when a port-less remote
// address is longer than 253 bytes.
var ErrRemoteAddrTooLong = errors.New("remote address too long")
// HTTPHandlerFunc adapts a net/http handler function into a Fiber handler by
// delegating to HTTPHandler.
func HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler {
	var asHandler http.Handler = h
	return HTTPHandler(asHandler)
}
// HTTPHandler adapts a net/http handler into a Fiber handler. The conversion to
// a fasthttp handler happens once; every request then runs it against the
// underlying fasthttp request context. The returned handler always returns nil.
func HTTPHandler(h http.Handler) fiber.Handler {
	fastHandler := fasthttpadaptor.NewFastHTTPHandler(h)
	return func(c fiber.Ctx) error {
		fastHandler(c.RequestCtx())
		return nil
	}
}
// HTTPHandlerWithContext is like HTTPHandler, but additionally stores Fiber's
// user context (c.Context()) as a user value of the fasthttp request context,
// where adapted net/http handlers can read it via LocalContextFromHTTPRequest(r).
func HTTPHandlerWithContext(h http.Handler) fiber.Handler {
	fastHandler := fasthttpadaptor.NewFastHTTPHandler(h)
	return func(c fiber.Ctx) error {
		reqCtx := c.RequestCtx()
		reqCtx.SetUserValue(localContextKey, c.Context())
		fastHandler(reqCtx)
		return nil
	}
}
// LocalContextFromHTTPRequest extracts the Fiber user context previously stored
// into r.Context() by HTTPHandlerWithContext. The boolean is false when r is nil
// or no context.Context is stored under the adaptor's key.
func LocalContextFromHTTPRequest(r *http.Request) (context.Context, bool) {
	if r == nil {
		return nil, false
	}
	stored := r.Context().Value(localContextKey)
	ctx, ok := stored.(context.Context)
	return ctx, ok
}
// ConvertRequest builds a net/http Request from the fasthttp request behind c.
// Set forServer to true when the result will be passed to an http.Handler.
// Conversion errors are returned unwrapped.
func ConvertRequest(c fiber.Ctx, forServer bool) (*http.Request, error) {
	req := new(http.Request)
	err := fasthttpadaptor.ConvertRequest(c.RequestCtx(), req, forServer)
	if err != nil {
		return nil, err //nolint:wrapcheck // This must not be wrapped
	}
	return req, nil
}
// CopyContextToFiberContext copies the values of context.Context to a fasthttp.RequestCtx.
//
// It dereferences src down to a struct and scans its fields in order. A field
// named "Context" is descended into recursively. A field named "key" is
// remembered, and the following field named "val" is stored in requestContext
// via SetUserValue(key, val). Scanning stops at a field named "noCopy".
// Nothing happens for a nil requestContext, an invalid or nil-pointer src, or a
// src that is not a struct. Unexported fields are read through unsafe pointers.
//
// NOTE(review): the field names presumably mirror the standard library's
// unexported context.WithValue implementation; that is an internal detail that
// may change between Go releases.
//
// Deprecated: This function uses reflection and unsafe pointers; consider using explicit context passing.
func CopyContextToFiberContext(src any, requestContext *fasthttp.RequestCtx) {
	if requestContext == nil {
		return
	}
	v := reflect.ValueOf(src)
	if !v.IsValid() {
		return
	}
	// Deref pointer chains
	for v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return
		}
		v = v.Elem()
	}
	t := v.Type()
	if t.Kind() != reflect.Struct {
		return
	}
	// UnsafeAddr below requires an addressable value, so copy a non-addressable
	// struct into a freshly allocated one first.
	if !v.CanAddr() {
		tmp := reflect.New(t)
		tmp.Elem().Set(v)
		v = tmp.Elem()
	}
	contextValues := v
	contextKeys := t
	// lastKey holds the most recent "key" field until its "val" field is seen.
	var lastKey any
	for i := 0; i < contextValues.NumField(); i++ {
		reflectValue := contextValues.Field(i)
		reflectField := contextKeys.Field(i)
		if reflectField.Name == "noCopy" {
			break
		}
		// Unexported fields cannot be read via Interface(); re-create them through
		// an unsafe pointer to the (addressable) field so their values can be read.
		if !reflectValue.CanInterface() {
			/* #nosec */
			reflectValue = reflect.NewAt(reflectValue.Type(), unsafe.Pointer(reflectValue.UnsafeAddr())).Elem()
		}
		switch reflectField.Name {
		case "Context":
			// Parent context: recurse to collect its values as well.
			CopyContextToFiberContext(reflectValue.Interface(), requestContext)
		case "key":
			lastKey = reflectValue.Interface()
		case "val":
			if lastKey != nil {
				requestContext.SetUserValue(lastKey, reflectValue.Interface())
				lastKey = nil
			}
		default:
			continue
		}
	}
}
// HTTPMiddleware wraps net/http middleware to fiber middleware.
//
// mw receives a terminal http.Handler. If mw calls it, the method, request URI,
// host, headers and context values of the (possibly modified) net/http request
// are copied onto the Fiber request, and c.Next() runs after mw returns. If mw
// does not call it, the chain stops and nil is returned. Headers that mw
// removes are not deleted from the Fiber request; only present ones are copied.
func HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler {
	return func(c fiber.Ctx) error {
		var next bool
		nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
			next = true
			c.Request().Header.SetMethod(r.Method)
			c.Request().SetRequestURI(r.RequestURI)
			c.Request().SetHost(r.Host)
			c.Request().Header.SetHost(r.Host)
			// Remove all cookies before setting, see https://github.com/valyala/fasthttp/pull/1864
			c.Request().Header.DelAllCookies()
			for key, values := range r.Header {
				// Set the first value (replacing whatever the Fiber request had for
				// key) and Add the rest. Calling Set for every value kept only the
				// last one and dropped the others of a multi-valued header.
				for i, v := range values {
					if i == 0 {
						c.Request().Header.Set(key, v)
						continue
					}
					c.Request().Header.Add(key, v)
				}
			}
			CopyContextToFiberContext(r.Context(), c.RequestCtx())
		})
		if err := HTTPHandler(mw(nextHandler))(c); err != nil {
			return err
		}
		if next {
			return c.Next()
		}
		return nil
	}
}
// FiberHandler exposes a single Fiber handler as a net/http http.Handler.
func FiberHandler(h fiber.Handler) http.Handler {
	fn := FiberHandlerFunc(h)
	return fn
}

// FiberHandlerFunc exposes a single Fiber handler as a net/http handler
// function, running it on a freshly created Fiber app with default config.
func FiberHandlerFunc(h fiber.Handler) http.HandlerFunc {
	app := fiber.New()
	return handlerFunc(app, h)
}

// FiberApp exposes an entire Fiber app (its routes and middleware) as a
// net/http handler function.
func FiberApp(app *fiber.App) http.HandlerFunc {
	return handlerFunc(app)
}
// isUnixNetwork reports whether network names one of the Unix-domain socket
// networks ("unix", "unixgram" or "unixpacket").
func isUnixNetwork(network string) bool {
	switch network {
	case "unix", "unixgram", "unixpacket":
		return true
	default:
		return false
	}
}
// resolveRemoteAddr converts a net/http RemoteAddr string into the net.Addr used
// for the fasthttp request context.
//
// If localAddr (the value stored under http.LocalAddrContextKey) is a Unix socket
// address, it is returned unchanged. Otherwise remoteAddr is resolved as a TCP
// address. When remoteAddr carries no port — a host or IPv4 literal ("1.2.3.4"),
// a bracketed IPv6 literal ("[::1]") or a bare IPv6 literal ("::1") — port 80 is
// assumed.
//
// It returns ErrRemoteAddrEmpty for an empty remoteAddr, ErrRemoteAddrTooLong when
// a port-less address is longer than 253 bytes, and a wrapped resolution error for
// anything else that cannot be resolved.
func resolveRemoteAddr(remoteAddr string, localAddr any) (net.Addr, error) {
	if addr, ok := localAddr.(net.Addr); ok && isUnixNetwork(addr.Network()) {
		return addr, nil
	}
	// Validate input to prevent malformed addresses
	if remoteAddr == "" {
		return nil, ErrRemoteAddrEmpty
	}
	resolved, err := net.ResolveTCPAddr("tcp", remoteAddr)
	if err == nil {
		return resolved, nil
	}

	var addrErr *net.AddrError
	if !errors.As(err, &addrErr) {
		return nil, fmt.Errorf("failed to resolve TCP address: %w", err)
	}
	switch addrErr.Err {
	case "missing port in address", "too many colons in address":
		// No port was supplied. "too many colons" is what net.SplitHostPort reports
		// for a bare IPv6 literal such as "::1"; fall through and add port 80.
	default:
		return nil, fmt.Errorf("failed to resolve TCP address: %w", err)
	}
	if len(remoteAddr) > 253 { // Max hostname length
		return nil, ErrRemoteAddrTooLong
	}
	host := remoteAddr
	// Drop surrounding brackets so JoinHostPort yields "[::1]:80", not "[[::1]]:80".
	if len(host) > 1 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	resolved, err = net.ResolveTCPAddr("tcp", net.JoinHostPort(host, "80"))
	if err != nil {
		return nil, fmt.Errorf("failed to resolve TCP address after adding port: %w", err)
	}
	return resolved, nil
}
// handlerFunc builds the net/http handler behind FiberHandlerFunc and FiberApp.
//
// It converts the incoming *http.Request into a pooled fasthttp request context.
// It then runs either h[0] (when given) on a fiber context acquired from app, or
// app.Handler() otherwise. Finally it copies the fasthttp response (headers,
// status and body) back to w. A streamed response body is forwarded in
// bufferSize chunks, with w flushed after each chunk, when w implements
// http.Flusher.
//
// Request bodies larger than app.Config().BodyLimit are answered with 413. This
// applies both when the declared Content-Length is too large and when a body of
// unknown length turns out to exceed the limit while it is read.
func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		req := fasthttp.AcquireRequest()
		defer fasthttp.ReleaseRequest(req)

		// Convert net/http -> fasthttp request with size limit
		maxBodySize := int64(app.Config().BodyLimit)
		if r.Body != nil {
			if r.ContentLength > maxBodySize {
				http.Error(w, utils.StatusMessage(fiber.StatusRequestEntityTooLarge), fiber.StatusRequestEntityTooLarge)
				return
			}
			// Read at most one byte past the limit. An oversized body of unknown
			// length (ContentLength == -1, e.g. chunked) is then rejected instead of
			// being silently truncated. The comparison keeps readLimit from
			// overflowing when maxBodySize is already the largest int64.
			readLimit := maxBodySize
			if readLimit+1 > readLimit {
				readLimit++
			}
			n, err := io.Copy(req.BodyWriter(), io.LimitReader(r.Body, readLimit))
			if err != nil {
				http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError)
				return
			}
			if n > maxBodySize {
				http.Error(w, utils.StatusMessage(fiber.StatusRequestEntityTooLarge), fiber.StatusRequestEntityTooLarge)
				return
			}
			req.Header.SetContentLength(int(n))
		}
		req.Header.SetMethod(r.Method)
		req.SetRequestURI(r.RequestURI)
		req.SetHost(r.Host)
		req.Header.SetHost(r.Host)
		// req is freshly acquired, so Add keeps every value of a multi-valued
		// header; Set would keep only the last one.
		for key, values := range r.Header {
			for _, v := range values {
				req.Header.Add(key, v)
			}
		}
		remoteAddr, err := resolveRemoteAddr(r.RemoteAddr, r.Context().Value(http.LocalAddrContextKey))
		if err != nil {
			remoteAddr = nil // Fallback to nil
		}

		// New fasthttp Ctx from pool
		fctx := ctxPool.Get().(*fasthttp.RequestCtx) //nolint:forcetypeassert,errcheck // not needed
		fctx.Response.Reset()
		fctx.Request.Reset()
		defer func() {
			// Reset the response before pooling it. This releases any response body
			// stream (e.g. from SendStreamWriter) now, rather than whenever this ctx
			// happens to be reused.
			fctx.Response.Reset()
			ctxPool.Put(fctx)
		}()
		fctx.Init(req, remoteAddr, &disableLogger{})

		if len(h) > 0 {
			// New fiber Ctx
			ctx := app.AcquireCtx(fctx)
			defer app.ReleaseCtx(ctx)
			// Execute fiber Ctx
			if handlerErr := h[0](ctx); handlerErr != nil {
				_ = app.Config().ErrorHandler(ctx, handlerErr) //nolint:errcheck // not needed
			}
		} else {
			// Execute fasthttp Ctx though app.Handler
			app.Handler()(fctx)
		}

		// Convert fasthttp Ctx -> net/http
		for k, v := range fctx.Response.Header.All() {
			w.Header().Add(string(k), string(v))
		}
		w.WriteHeader(fctx.Response.StatusCode())

		// Check if streaming is not possible or unnecessary.
		bodyStream := fctx.Response.BodyStream()
		flusher, ok := w.(http.Flusher)
		if !ok || bodyStream == nil {
			_, _ = w.Write(fctx.Response.Body()) //nolint:errcheck // not needed
			return
		}

		// Stream fctx.Response.BodyStream() -> w in chunks.
		bufPtr, ok := bufferPool.Get().(*[bufferSize]byte)
		if !ok {
			panic(fmt.Errorf("failed to type-assert to *[%d]byte", bufferSize))
		}
		defer bufferPool.Put(bufPtr)
		buf := bufPtr[:]
		for {
			n, readErr := bodyStream.Read(buf)
			if n > 0 {
				if _, writeErr := w.Write(buf[:n]); writeErr != nil {
					break
				}
				flusher.Flush()
			}
			if readErr != nil {
				break
			}
		}
	}
}
================================================
FILE: middleware/adaptor/adaptor_test.go
================================================
//nolint:contextcheck,revive // Much easier to just ignore memory leaks in tests
package adaptor
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// Request fixtures shared by the adaptor tests.
const (
	// expectedHost is the Host sent with the test requests.
	expectedHost = "foobar.com"
	// expectedRequestURI is the request target, including a query string.
	expectedRequestURI = "/foo/bar?baz=123"
	// expectedRemoteAddr is the client address, an IPv4 literal with explicit port.
	expectedRemoteAddr = "1.2.3.4:6789"
	// expectedBody is the request payload; several handlers echo it back.
	expectedBody = "body 123 foo bar baz"
)
// Test_HTTPHandler checks the fiber -> net/http direction.
//
// The wrapped handler must see the original method, protocol, URI, Content-Length,
// host, remote address, body and headers. It must also see a fiber Locals value as
// a request context value. Whatever it writes must end up in the fasthttp response.
func Test_HTTPHandler(t *testing.T) {
	t.Parallel()

	type contextKeyType string
	var (
		wantMethod     = fiber.MethodPost
		wantProto      = "HTTP/1.1"
		wantProtoMajor = 1
		wantProtoMinor = 1
		wantLength     = len(expectedBody)
		wantCtxKey     = contextKeyType("contextKey")
		wantCtxValue   = "contextValue"
		wantHeaders    = map[string]string{
			"Foo-Bar":         "baz",
			"Abc":             "defg",
			"XXX-Remote-Addr": "123.43.4543.345",
		}
	)
	wantURL, err := url.ParseRequestURI(expectedRequestURI)
	require.NoError(t, err)

	calls := 0
	netHandler := func(w http.ResponseWriter, r *http.Request) {
		calls++
		assert.Equal(t, wantMethod, r.Method, "Method")
		assert.Equal(t, wantProto, r.Proto, "Proto")
		assert.Equal(t, wantProtoMajor, r.ProtoMajor, "ProtoMajor")
		assert.Equal(t, wantProtoMinor, r.ProtoMinor, "ProtoMinor")
		assert.Equal(t, expectedRequestURI, r.RequestURI, "RequestURI")
		assert.Equal(t, wantLength, int(r.ContentLength), "ContentLength")
		assert.Empty(t, r.TransferEncoding, "TransferEncoding")
		assert.Equal(t, expectedHost, r.Host, "Host")
		assert.Equal(t, expectedRemoteAddr, r.RemoteAddr, "RemoteAddr")

		reqBody, readErr := io.ReadAll(r.Body)
		assert.NoError(t, readErr)
		assert.Equal(t, expectedBody, string(reqBody), "Body")
		assert.Equal(t, wantURL, r.URL, "URL")
		assert.Equal(t, wantCtxValue, r.Context().Value(wantCtxKey), "Context")
		for name, want := range wantHeaders {
			assert.Equal(t, want, r.Header.Get(name), "Header")
		}

		w.Header().Set("Header1", "value1")
		w.Header().Set("Header2", "value2")
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, "request body is %q", reqBody)
	}

	handler := setFiberContextValueMiddleware(
		HTTPHandlerFunc(http.HandlerFunc(netHandler)),
		wantCtxKey, wantCtxValue,
	)

	// Assemble the incoming fasthttp request by hand.
	var req fasthttp.Request
	req.Header.SetMethod(wantMethod)
	req.SetRequestURI(expectedRequestURI)
	req.Header.SetHost(expectedHost)
	req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck // not needed
	for name, value := range wantHeaders {
		req.Header.Set(name, value)
	}
	clientAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr)
	require.NoError(t, err)

	var fctx fasthttp.RequestCtx
	fctx.Init(&req, clientAddr, &disableLogger{})

	app := fiber.New()
	ctx := app.AcquireCtx(&fctx)
	defer app.ReleaseCtx(ctx)

	require.NoError(t, handler(ctx))
	require.Equal(t, 1, calls, "callsCount")

	resp := &fctx.Response
	require.Equal(t, http.StatusBadRequest, resp.StatusCode(), "StatusCode")
	require.Equal(t, "value1", string(resp.Header.Peek("Header1")), "Header1")
	require.Equal(t, "value2", string(resp.Header.Peek("Header2")), "Header2")
	require.Equal(t, fmt.Sprintf("request body is %q", expectedBody), string(resp.Body()), "Body")
}
// Test_HTTPHandler_Flush is Test_HTTPHandler with a handler that flushes halfway
// through its response. Both halves must still arrive in the fasthttp response
// body, in order.
func Test_HTTPHandler_Flush(t *testing.T) {
	t.Parallel()
	expectedMethod := fiber.MethodPost
	expectedProto := "HTTP/1.1"
	expectedProtoMajor := 1
	expectedProtoMinor := 1
	expectedContentLength := len(expectedBody)
	expectedHeader := map[string]string{
		"Foo-Bar":         "baz",
		"Abc":             "defg",
		"XXX-Remote-Addr": "123.43.4543.345",
	}
	expectedURL, err := url.ParseRequestURI(expectedRequestURI)
	require.NoError(t, err)
	type contextKeyType string
	expectedContextKey := contextKeyType("contextKey")
	expectedContextValue := "contextValue"
	callsCount := 0
	nethttpH := func(w http.ResponseWriter, r *http.Request) {
		callsCount++
		assert.Equal(t, expectedMethod, r.Method, "Method")
		assert.Equal(t, expectedProto, r.Proto, "Proto")
		assert.Equal(t, expectedProtoMajor, r.ProtoMajor, "ProtoMajor")
		assert.Equal(t, expectedProtoMinor, r.ProtoMinor, "ProtoMinor")
		assert.Equal(t, expectedRequestURI, r.RequestURI, "RequestURI")
		assert.Equal(t, expectedContentLength, int(r.ContentLength), "ContentLength")
		assert.Empty(t, r.TransferEncoding, "TransferEncoding")
		assert.Equal(t, expectedHost, r.Host, "Host")
		assert.Equal(t, expectedRemoteAddr, r.RemoteAddr, "RemoteAddr")
		body, readErr := io.ReadAll(r.Body)
		assert.NoError(t, readErr)
		assert.Equal(t, expectedBody, string(body), "Body")
		assert.Equal(t, expectedURL, r.URL, "URL")
		assert.Equal(t, expectedContextValue, r.Context().Value(expectedContextKey), "Context")
		for k, expectedV := range expectedHeader {
			v := r.Header.Get(k)
			assert.Equal(t, expectedV, v, "Header")
		}
		w.Header().Set("Header1", "value1")
		w.Header().Set("Header2", "value2")
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, "request body is ")
		flusher, ok := w.(http.Flusher)
		if !ok {
			// The adaptor may run this handler off the test goroutine, where
			// t.Fatal must not be called; report and bail out instead.
			t.Error("w does not implement http.Flusher")
			return
		}
		flusher.Flush()
		fmt.Fprintf(w, "%q", body)
	}
	fiberH := HTTPHandlerFunc(http.HandlerFunc(nethttpH))
	fiberH = setFiberContextValueMiddleware(fiberH, expectedContextKey, expectedContextValue)
	var fctx fasthttp.RequestCtx
	var req fasthttp.Request
	req.Header.SetMethod(expectedMethod)
	req.SetRequestURI(expectedRequestURI)
	req.Header.SetHost(expectedHost)
	req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck // not needed
	for k, v := range expectedHeader {
		req.Header.Set(k, v)
	}
	remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr)
	require.NoError(t, err)
	fctx.Init(&req, remoteAddr, &disableLogger{})
	app := fiber.New()
	ctx := app.AcquireCtx(&fctx)
	defer app.ReleaseCtx(ctx)
	err = fiberH(ctx)
	require.NoError(t, err)
	require.Equal(t, 1, callsCount, "callsCount")
	resp := &fctx.Response
	require.Equal(t, http.StatusBadRequest, resp.StatusCode(), "StatusCode")
	require.Equal(t, "value1", string(resp.Header.Peek("Header1")), "Header1")
	require.Equal(t, "value2", string(resp.Header.Peek("Header2")), "Header2")
	expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
	require.Equal(t, expectedResponseBody, string(resp.Body()), "Body")
}
// Test_HTTPHandler_Flush_App_Test serves a flushing net/http handler through a
// fiber route via app.Test. The client must receive both writes ("Hello " and
// "World!") as one complete body.
func Test_HTTPHandler_Flush_App_Test(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Get("/", HTTPHandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		flusher, ok := w.(http.Flusher)
		if !ok {
			// This handler is not guaranteed to run on the test goroutine, so
			// t.Fatal (FailNow) must not be used here.
			t.Error("w does not implement http.Flusher")
			return
		}
		w.WriteHeader(fiber.StatusOK)
		fmt.Fprintf(w, "Hello ")
		flusher.Flush()
		time.Sleep(500 * time.Millisecond)
		fmt.Fprintf(w, "World!")
	}))
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // not needed
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "Hello World!", string(body))
}
// Test_HTTPHandler_App_Test_Interrupted checks the timeout path of app.Test.
// The handler sleeps 500ms after flushing, and app.Test runs with a 200ms timeout
// and FailOnTimeout. The call must fail with os.ErrDeadlineExceeded and return no
// response.
func Test_HTTPHandler_App_Test_Interrupted(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Get("/", HTTPHandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		flusher, ok := w.(http.Flusher)
		if !ok {
			// Not on the test goroutine: use t.Error + return instead of t.Fatalf.
			t.Error("w does not implement http.Flusher")
			return
		}
		w.WriteHeader(fiber.StatusOK)
		fmt.Fprintf(w, "Hello ")
		flusher.Flush()
		time.Sleep(500 * time.Millisecond)
		fmt.Fprintf(w, "World!")
	}))
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{
		Timeout:       200 * time.Millisecond,
		FailOnTimeout: true, // Changed to true to test interrupted behavior
	})
	// With FailOnTimeout: true, we should get a timeout error
	require.ErrorIs(t, err, os.ErrDeadlineExceeded)
	require.Nil(t, resp)
}
// Test_LocalContextFromHTTPRequest covers the three lookup outcomes:
//   - a nil request reports (nil, false);
//   - a request without the local-context key reports (nil, false);
//   - a request carrying the key returns the stored context and true.
func Test_LocalContextFromHTTPRequest(t *testing.T) {
	t.Parallel()

	storedCtx := context.WithValue(context.Background(), contextKey("k"), "v")

	testCases := []struct {
		name    string
		newReq  func() *http.Request
		wantCtx context.Context // nil: no local context expected
	}{
		{
			name:   "nil request",
			newReq: func() *http.Request { return nil },
		},
		{
			name: "request without stored context key",
			newReq: func() *http.Request {
				return httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
			},
		},
		{
			name: "request with stored context key",
			newReq: func() *http.Request {
				return httptest.NewRequest(fiber.MethodGet, "/", http.NoBody).WithContext(
					context.WithValue(context.Background(), localContextKey, storedCtx),
				)
			},
			wantCtx: storedCtx,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got, found := LocalContextFromHTTPRequest(tc.newReq())
			if tc.wantCtx == nil {
				require.False(t, found)
				require.Nil(t, got)
				return
			}
			require.True(t, found)
			require.Equal(t, tc.wantCtx, got)
		})
	}
}
// Test_HTTPHandlerWithContext_local_context checks that a value stored in the
// fiber context by an earlier middleware (via c.SetContext) can be read by a
// net/http handler wrapped with HTTPHandlerWithContext, through
// LocalContextFromHTTPRequest.
func Test_HTTPHandlerWithContext_local_context(t *testing.T) {
	t.Parallel()

	// A private key type keeps this value from colliding with other context keys.
	type key struct{}
	var testKey key
	const testVal string = "test-value"

	app := fiber.New()

	// Middleware: put the value into the fiber-side context.
	app.Use(func(c fiber.Ctx) error {
		c.SetContext(context.WithValue(c.Context(), testKey, testVal))
		return c.Next()
	})

	// Handler: read it back from the local context and echo it.
	netHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		localCtx, found := LocalContextFromHTTPRequest(r)
		if !found {
			http.Error(w, "local context not found", http.StatusInternalServerError)
			return
		}
		stored, isString := localCtx.Value(testKey).(string)
		if !isString {
			http.Error(w, "invalid context value", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusOK)
		if _, writeErr := w.Write([]byte(stored)); writeErr != nil {
			t.Logf("write failed: %v", writeErr)
		}
	})
	app.Get("/", HTTPHandlerWithContext(netHandler))

	cfg := fiber.TestConfig{
		Timeout:       200 * time.Millisecond,
		FailOnTimeout: false,
	}
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), cfg)
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // no need

	payload, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, testVal, string(payload))
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
// contextKey is a distinct string type for context keys used in these tests.
// Using it avoids collisions with keys defined in other packages.
type contextKey string

// String renders the key with a "test-" prefix.
func (c contextKey) String() string {
	const prefix = "test-"
	return prefix + string(c)
}

// Context keys that the net/http middleware in Test_HTTPMiddleware attaches to requests.
var (
	TestContextKey       = contextKey("TestContextKey")
	TestContextSecondKey = contextKey("TestContextSecondKey")
)
// Test_HTTPMiddleware checks three things about a net/http middleware wrapped
// with HTTPMiddleware:
//   - it can short-circuit a request (405 for non-POST methods);
//   - when it calls next, fiber routing continues (200 for POST /, 404 for an
//     unknown path);
//   - context values it adds are visible to the fiber handler through
//     c.RequestCtx().Value, with the last value winning when a key is set twice.
func Test_HTTPMiddleware(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name       string
		url        string
		method     string
		statusCode int
	}{
		{name: "Should return 200", url: "/", method: "POST", statusCode: 200},
		{name: "Should return 405", url: "/", method: "GET", statusCode: 405},
		{name: "Should return 404", url: "/unknown", method: "POST", statusCode: 404},
	}
	nethttpMW := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Method != http.MethodPost {
				w.WriteHeader(http.StatusMethodNotAllowed)
				return
			}
			r = r.WithContext(context.WithValue(r.Context(), TestContextKey, "okay"))
			r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "not_okay"))
			r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "okay"))
			next.ServeHTTP(w, r)
		})
	}
	app := fiber.New()
	app.Use(HTTPMiddleware(nethttpMW))
	app.Post("/", func(c fiber.Ctx) error {
		// Check the type before using the value (the original asserted even for nil).
		if val, ok := c.RequestCtx().Value(TestContextKey).(string); ok {
			c.Set("context_okay", val)
		} else {
			t.Error("unexpected error on type-assertion")
		}
		if value := c.RequestCtx().Value(TestContextSecondKey); value != nil {
			val, ok := value.(string)
			if !ok {
				t.Error("unexpected error on type-assertion")
			}
			c.Set("context_second_okay", val)
		}
		return c.SendStatus(fiber.StatusOK)
	})
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := http.NewRequestWithContext(context.Background(), tt.method, tt.url, http.NoBody)
			require.NoError(t, err) // before touching req, which is nil on error
			req.Host = expectedHost
			resp, err := app.Test(req)
			require.NoError(t, err)
			defer resp.Body.Close() //nolint:errcheck // not needed
			require.Equal(t, tt.statusCode, resp.StatusCode, "StatusCode")
		})
	}
	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", http.NoBody)
	require.NoError(t, err)
	req.Host = expectedHost
	resp, err := app.Test(req)
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // not needed
	require.Equal(t, "okay", resp.Header.Get("context_okay"))
	require.Equal(t, "okay", resp.Header.Get("context_second_okay"))
}
// Test_HTTPMiddlewareWithCookies checks that request cookies survive the round
// trip through a net/http middleware wrapped with HTTPMiddleware. The fiber
// handler echoes every request cookie back as a Set-Cookie header. The
// middleware itself rejects non-POST requests with 405.
func Test_HTTPMiddlewareWithCookies(t *testing.T) {
	t.Parallel()
	const (
		cookieHeader    = "Cookie"
		setCookieHeader = "Set-Cookie"
		cookieOneName   = "cookieOne"
		cookieTwoName   = "cookieTwo"
		cookieOneValue  = "valueCookieOne"
		cookieTwoValue  = "valueCookieTwo"
	)
	nethttpMW := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Method != http.MethodPost {
				w.WriteHeader(http.StatusMethodNotAllowed)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
	app := fiber.New()
	app.Use(HTTPMiddleware(nethttpMW))
	app.Post("/", func(c fiber.Ctx) error {
		// RETURNING CURRENT COOKIES TO RESPONSE
		for _, cookie := range strings.Split(c.Get(cookieHeader), "; ") {
			if cookie == "" {
				// Without a Cookie header, Split yields one empty element. Skip it
				// rather than emitting an empty Set-Cookie value.
				continue
			}
			c.Set(setCookieHeader, cookie)
		}
		return c.SendStatus(fiber.StatusOK)
	})
	// Test case for POST request with cookies
	t.Run("POST request with cookies", func(t *testing.T) {
		t.Parallel()
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", http.NoBody)
		require.NoError(t, err)
		req.AddCookie(&http.Cookie{Name: cookieOneName, Value: cookieOneValue})
		req.AddCookie(&http.Cookie{Name: cookieTwoName, Value: cookieTwoValue})
		resp, err := app.Test(req)
		require.NoError(t, err)
		defer resp.Body.Close() //nolint:errcheck // not needed
		cookies := resp.Cookies()
		require.Len(t, cookies, 2)
		for _, cookie := range cookies {
			switch cookie.Name {
			case cookieOneName:
				require.Equal(t, cookieOneValue, cookie.Value)
			case cookieTwoName:
				require.Equal(t, cookieTwoValue, cookie.Value)
			default:
				t.Error("unexpected cookie key")
			}
		}
	})
	// New test case for GET request
	t.Run("GET request", func(t *testing.T) {
		t.Parallel()
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/", http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		defer resp.Body.Close() //nolint:errcheck // not needed
		require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
	})
	// New test case for request without cookies
	t.Run("POST request without cookies", func(t *testing.T) {
		t.Parallel()
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		defer resp.Body.Close() //nolint:errcheck // not needed
		require.Equal(t, http.StatusOK, resp.StatusCode)
		require.Empty(t, resp.Cookies())
	})
}
func Test_FiberHandler(t *testing.T) {
t.Parallel()
testFiberToHandlerFunc(t, false)
}
func Test_FiberHandler_BodyLimit(t *testing.T) {
t.Parallel()
tests := []struct {
name string
bodyLimit int
bodySize int
expectedStatus int
}{
{
name: "DefaultLimitExceededReturns413",
bodySize: fiber.DefaultBodyLimit + 1024,
expectedStatus: fiber.StatusRequestEntityTooLarge,
},
{
name: "CustomLimitExceededReturns413",
bodyLimit: 1 * 1024 * 1024,
bodySize: (1 * 1024 * 1024) + 1,
expectedStatus: fiber.StatusRequestEntityTooLarge,
},
{
name: "CustomLimitAllowsLargerPayload",
bodyLimit: 2 * fiber.DefaultBodyLimit,
bodySize: fiber.DefaultBodyLimit + 512,
expectedStatus: fiber.StatusOK,
},
{
name: "ZeroLimitConfigFallsBackToDefault",
bodyLimit: 0,
bodySize: fiber.DefaultBodyLimit - 256,
expectedStatus: fiber.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
app := fiber.New(fiber.Config{
BodyLimit: tt.bodyLimit,
})
app.Post("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
handlerFunc := FiberApp(app)
body := bytes.Repeat([]byte("a"), tt.bodySize)
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body))
req.ContentLength = int64(len(body))
resp := httptest.NewRecorder()
handlerFunc.ServeHTTP(resp, req)
require.Equal(t, tt.expectedStatus, resp.Code)
})
}
}
// Test_FiberApp runs the shared conversion checks through FiberApp, with the
// handler registered on a fiber.App router.
func Test_FiberApp(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	testFiberToHandlerFunc(t, false, app)
}
func Test_FiberHandlerDefaultPort(t *testing.T) {
t.Parallel()
testFiberToHandlerFunc(t, true)
}
// Test_FiberAppDefaultPort checks, via FiberApp, that a RemoteAddr without a port
// is reported to the fiber handler with port 80.
func Test_FiberAppDefaultPort(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	testFiberToHandlerFunc(t, true, app)
}
// testFiberToHandlerFunc drives a POST /foo/bar?baz=123 request through
// FiberHandlerFunc (when app is empty) or FiberApp(app[0]). It checks that:
//   - the fiber handler sees the original method, URI, Content-Length, host,
//     remote address, body and headers;
//   - the handler runs exactly once;
//   - its status, headers and body are copied back to the net/http writer.
//
// With checkDefaultPort set, the request's RemoteAddr has no port and the handler
// must observe port 80.
func testFiberToHandlerFunc(t *testing.T, checkDefaultPort bool, app ...*fiber.App) {
	t.Helper()
	expectedMethod := fiber.MethodPost
	expectedContentLength := len(expectedBody)
	expectedRemoteAddr := "1.2.3.4:6789"
	if checkDefaultPort {
		expectedRemoteAddr = "1.2.3.4:80"
	}
	expectedHeader := map[string]string{
		"Foo-Bar":         "baz",
		"Abc":             "defg",
		"XXX-Remote-Addr": "123.43.4543.345",
	}
	expectedURL, err := url.ParseRequestURI(expectedRequestURI)
	require.NoError(t, err)
	callsCount := 0
	fiberH := func(c fiber.Ctx) error {
		callsCount++
		require.Equal(t, expectedMethod, c.Method(), "Method")
		require.Equal(t, expectedRequestURI, string(c.RequestCtx().RequestURI()), "RequestURI")
		require.Equal(t, expectedContentLength, c.RequestCtx().Request.Header.ContentLength(), "ContentLength")
		require.Equal(t, expectedHost, c.Hostname(), "Host")
		require.Equal(t, expectedHost, string(c.Request().Header.Host()), "Host")
		require.Equal(t, "http://"+expectedHost, c.BaseURL(), "BaseURL")
		require.Equal(t, expectedRemoteAddr, c.RequestCtx().RemoteAddr().String(), "RemoteAddr")
		body := string(c.Body())
		require.Equal(t, expectedBody, body, "Body")
		require.Equal(t, expectedURL.String(), c.OriginalURL(), "URL")
		for k, expectedV := range expectedHeader {
			v := c.Get(k)
			require.Equal(t, expectedV, v, "Header")
		}
		c.Set("Header1", "value1")
		c.Set("Header2", "value2")
		c.Status(fiber.StatusBadRequest)
		_, err := c.Write(fmt.Appendf(nil, "request body is %q", body))
		return err
	}
	// Named "handler" so it does not shadow the package-level handlerFunc.
	var handler http.HandlerFunc
	if len(app) > 0 {
		app[0].Post("/foo/bar", fiberH)
		handler = FiberApp(app[0])
	} else {
		handler = FiberHandlerFunc(fiberH)
	}
	var r http.Request
	r.Method = expectedMethod
	r.Body = &netHTTPBody{b: []byte(expectedBody)}
	r.RequestURI = expectedRequestURI
	r.ContentLength = int64(expectedContentLength)
	r.Host = expectedHost
	r.RemoteAddr = expectedRemoteAddr
	if checkDefaultPort {
		r.RemoteAddr = "1.2.3.4"
	}
	hdr := make(http.Header)
	for k, v := range expectedHeader {
		hdr.Set(k, v)
	}
	r.Header = hdr
	var w netHTTPResponseWriter
	handler.ServeHTTP(&w, &r)
	// Without this check a route mismatch would go unnoticed: the handler's own
	// assertions simply would not run.
	require.Equal(t, 1, callsCount, "callsCount")
	require.Equal(t, http.StatusBadRequest, w.StatusCode(), "StatusCode")
	require.Equal(t, "value1", w.Header().Get("Header1"), "Header1")
	require.Equal(t, "value2", w.Header().Get("Header2"), "Header2")
	expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
	require.Equal(t, expectedResponseBody, string(w.body), "Body")
}
// setFiberContextValueMiddleware returns a handler that stores value under key in
// the fiber context's Locals and then delegates to next.
func setFiberContextValueMiddleware(next fiber.Handler, key, value any) fiber.Handler {
	wrapped := func(c fiber.Ctx) error {
		c.Locals(key, value)
		return next(c)
	}
	return wrapped
}
// Test_FiberHandler_RequestNilBody checks that FiberHandler accepts an
// *http.Request whose Body and Header are nil. The fiber handler must run exactly
// once and see Content-Length 0. Its 200 response must reach the net/http writer.
func Test_FiberHandler_RequestNilBody(t *testing.T) {
	t.Parallel()
	expectedMethod := fiber.MethodGet
	expectedRequestURI := "/foo/bar"
	expectedContentLength := 0
	callsCount := 0
	fiberH := func(c fiber.Ctx) error {
		callsCount++
		require.Equal(t, expectedMethod, c.Method(), "Method")
		require.Equal(t, expectedRequestURI, string(c.RequestCtx().RequestURI()), "RequestURI")
		require.Equal(t, expectedContentLength, c.RequestCtx().Request.Header.ContentLength(), "ContentLength")
		_, err := c.WriteString("request body is nil")
		return err
	}
	nethttpH := FiberHandler(fiberH)
	var r http.Request
	r.Method = expectedMethod
	r.RequestURI = expectedRequestURI
	var w netHTTPResponseWriter
	nethttpH.ServeHTTP(&w, &r)
	require.Equal(t, 1, callsCount, "callsCount")
	require.Equal(t, http.StatusOK, w.StatusCode(), "StatusCode")
	expectedResponseBody := "request body is nil"
	require.Equal(t, expectedResponseBody, string(w.body), "Body")
}
// netHTTPBody is a minimal io.ReadCloser over an in-memory byte slice. The tests
// use it as a request body.
type netHTTPBody struct {
	b []byte
}

// Read copies as much of the remaining data into p as fits and consumes it.
// Once nothing is left it returns (0, io.EOF).
func (r *netHTTPBody) Read(p []byte) (int, error) {
	if len(r.b) > 0 {
		copied := copy(p, r.b)
		r.b = r.b[copied:]
		return copied, nil
	}
	return 0, io.EOF
}

// Close discards any unread data; it never fails.
func (r *netHTTPBody) Close() error {
	r.b = r.b[:0]
	return nil
}
func createTestRequest(method, uri, remoteAddr string, body io.Reader) *http.Request {
r := &http.Request{
Method: method,
RequestURI: uri,
RemoteAddr: remoteAddr,
Header: make(http.Header),
Body: http.NoBody,
}
if body != nil {
if rc, ok := body.(io.ReadCloser); ok {
r.Body = rc
} else {
r.Body = io.NopCloser(body)
}
}
return r
}
// executeHandlerTest serves req with handler into a fresh in-memory
// netHTTPResponseWriter and returns it for inspection. The *testing.T parameter
// is unused.
func executeHandlerTest(_ *testing.T, handler http.HandlerFunc, req *http.Request) *netHTTPResponseWriter {
	recorder := new(netHTTPResponseWriter)
	handler(recorder, req)
	return recorder
}
// netHTTPResponseWriter is an in-memory http.ResponseWriter, and a no-op
// http.Flusher. It records the status code, headers and body a handler writes.
type netHTTPResponseWriter struct {
	h          http.Header
	body       []byte
	statusCode int
}

// StatusCode returns the recorded status. It reports http.StatusOK when
// WriteHeader was never called with a non-zero code.
func (w *netHTTPResponseWriter) StatusCode() int {
	if w.statusCode != 0 {
		return w.statusCode
	}
	return http.StatusOK
}

// Header returns the response header map, allocating it on first use.
func (w *netHTTPResponseWriter) Header() http.Header {
	if w.h == nil {
		w.h = http.Header{}
	}
	return w.h
}

// WriteHeader records statusCode; a later call overwrites an earlier one.
func (w *netHTTPResponseWriter) WriteHeader(statusCode int) {
	w.statusCode = statusCode
}

// Write appends p to the recorded body and never fails.
func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
	w.body = append(w.body, p...)
	return len(p), nil
}

// Flush does nothing; it exists so the writer satisfies http.Flusher.
func (w *netHTTPResponseWriter) Flush() {}
// Test_ConvertRequest checks two things:
//   - ConvertRequest turns a live fiber context into an *http.Request with the
//     original URL;
//   - ConvertRequest does not panic when given malformed request URIs.
func Test_ConvertRequest(t *testing.T) {
	t.Parallel()
	t.Run("successful conversion", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Get("/test", func(c fiber.Ctx) error {
			httpReq, err := ConvertRequest(c, false)
			if err != nil {
				return err
			}
			return c.SendString("Request URL: " + httpReq.URL.String())
		})
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test?hello=world&another=test", http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		defer resp.Body.Close() //nolint:errcheck // not needed
		require.Equal(t, http.StatusOK, resp.StatusCode, "Status code")
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, "Request URL: /test?hello=world&another=test", string(body))
	})
	t.Run("conversion error handling", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)
		ctx.Request().Header.SetMethod(fiber.MethodGet)
		// fasthttpadaptor may accept or reject these URIs, so no specific error is
		// asserted. The contract checked is that ConvertRequest never panics on
		// them. forServer=true is used because it does more validation.
		for _, uri := range []string{
			"http://[::1:bad:url", // Invalid URL format
			"\x00\x01\x02",        // Invalid characters in URI
		} {
			ctx.Request().SetRequestURI(uri)
			require.NotPanics(t, func() {
				_, _ = ConvertRequest(ctx, true) //nolint:errcheck // only the absence of a panic matters
			}, "URI %q", uri)
		}
	})
}
// Test_CopyContextToFiberContext covers the inputs CopyContextToFiberContext
// must tolerate without panicking. These are non-struct values, unknown struct
// fields, nil sources, nil destinations, nil pointers, pointer chains and
// non-addressable structs. It also checks that key/value pairs are copied into
// the fasthttp user values, including through nested context wrappers.
func Test_CopyContextToFiberContext(t *testing.T) {
	t.Parallel()
	t.Run("unsupported context type", func(t *testing.T) {
		t.Parallel()
		var fctx fasthttp.RequestCtx
		stringContext := "not a struct"
		// A pointer to a non-struct value must be ignored without panicking.
		require.NotPanics(t, func() {
			CopyContextToFiberContext(&stringContext, &fctx)
		})
	})
	t.Run("context with unknown field", func(t *testing.T) {
		t.Parallel()
		// Test the default case (continue statement coverage)
		type customContext struct {
			UnknownField string
		}
		var fctx fasthttp.RequestCtx
		ctx := customContext{UnknownField: "test"}
		require.NotPanics(t, func() {
			CopyContextToFiberContext(&ctx, &fctx)
		})
	})
	t.Run("invalid src", func(t *testing.T) {
		t.Parallel()
		var fctx fasthttp.RequestCtx
		require.NotPanics(t, func() {
			CopyContextToFiberContext(nil, &fctx)
		})
	})
	t.Run("nil request context", func(t *testing.T) {
		t.Parallel()
		ctx := context.WithValue(context.Background(), contextKey("nil-request-context"), "value")
		require.NotPanics(t, func() {
			CopyContextToFiberContext(ctx, nil)
		})
	})
	t.Run("nil pointer", func(t *testing.T) {
		t.Parallel()
		var nilPtr *context.Context // Nil pointer to a context
		var fctx fasthttp.RequestCtx
		require.NotPanics(t, func() {
			CopyContextToFiberContext(nilPtr, &fctx)
		})
	})
	t.Run("copies key value pairs", func(t *testing.T) {
		t.Parallel()
		var fctx fasthttp.RequestCtx
		key := contextKey("copy-key")
		expectedValue := "copy-value"
		ctx := context.WithValue(context.Background(), key, expectedValue)
		CopyContextToFiberContext(ctx, &fctx)
		require.Equal(t, expectedValue, fctx.UserValue(key))
	})
	t.Run("nested context wrappers", func(t *testing.T) {
		t.Parallel()
		var fctx fasthttp.RequestCtx
		keyA := contextKey("nested-a")
		keyB := contextKey("nested-b")
		baseCtx := context.WithValue(context.Background(), keyA, "value-a")
		cancelCtx, cancel := context.WithCancel(baseCtx)
		t.Cleanup(cancel)
		wrappedCtx := context.WithValue(cancelCtx, keyB, "value-b")
		CopyContextToFiberContext(wrappedCtx, &fctx)
		require.Equal(t, "value-a", fctx.UserValue(keyA))
		require.Equal(t, "value-b", fctx.UserValue(keyB))
	})
	t.Run("multi-level pointer", func(t *testing.T) {
		t.Parallel()
		var fctx fasthttp.RequestCtx
		ctx := context.Background()
		ptr := &ctx
		doublePtr := &ptr
		// Pointer chains must be dereferenced without panicking.
		require.NotPanics(t, func() {
			CopyContextToFiberContext(doublePtr, &fctx)
		})
	})
	t.Run("non-addressable struct", func(t *testing.T) {
		t.Parallel()
		var fctx fasthttp.RequestCtx
		type testStruct struct {
			Field string
		}
		// Pass a struct value directly to exercise the addressability check.
		require.NotPanics(t, func() {
			CopyContextToFiberContext(testStruct{Field: "test"}, &fctx)
		})
	})
}
// Test_HTTPMiddleware_ErrorHandling covers a net/http middleware that writes its
// own status (500) and then calls next. An error returned by the downstream fiber
// handler must still determine the final response through fiber's error handler
// (400).
func Test_HTTPMiddleware_ErrorHandling(t *testing.T) {
	t.Parallel()
	// The middleware writes a 500 status before delegating to the next handler.
	statusThenNext := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusInternalServerError)
			next.ServeHTTP(w, r)
		})
	}
	app := fiber.New()
	app.Use(HTTPMiddleware(statusThenNext))
	app.Get("/error", func(_ fiber.Ctx) error {
		return fiber.NewError(fiber.StatusBadRequest, "test error")
	})
	resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/error", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // not needed
	// The handler's error, handled by the error handler, wins over the
	// middleware's 500.
	require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
}
// Test_FiberHandler_IOError checks that a request body which fails while being
// copied into the fasthttp request produces a 500 response.
func Test_FiberHandler_IOError(t *testing.T) {
	t.Parallel()

	handler := FiberHandlerFunc(func(c fiber.Ctx) error {
		return c.SendString("should not reach here")
	})

	// Every Read on this body fails, so copying it errors out.
	req := &http.Request{
		Method:        http.MethodPost,
		RequestURI:    "/test",
		Header:        make(http.Header),
		ContentLength: 100, // Set content length so it tries to read
		Body:          &failingReader{},
	}
	rec := &netHTTPResponseWriter{}
	handler.ServeHTTP(rec, req)

	// Should return 500 due to io.Copy error
	require.Equal(t, http.StatusInternalServerError, rec.StatusCode())
}
// Test_FiberHandler_WithErrorInHandler checks that a *fiber.Error returned by the
// fiber handler is turned into the matching HTTP status (418) by the error handler.
func Test_FiberHandler_WithErrorInHandler(t *testing.T) {
	t.Parallel()

	handler := FiberHandlerFunc(func(_ fiber.Ctx) error {
		return fiber.NewError(fiber.StatusTeapot, "I'm a teapot")
	})

	req := &http.Request{
		Method:     http.MethodGet,
		RequestURI: "/test",
		Header:     make(http.Header),
		Body:       http.NoBody,
	}
	rec := &netHTTPResponseWriter{}
	handler.ServeHTTP(rec, req)

	// Should return the error status
	require.Equal(t, fiber.StatusTeapot, rec.StatusCode())
}
// Test_FiberHandler_WithSendStreamWriter checks that a streamed fiber response
// (SendStreamWriter, with a flush and a pause between the two writes) is fully
// forwarded to the net/http writer. The status the handler set is preserved.
func Test_FiberHandler_WithSendStreamWriter(t *testing.T) {
	t.Parallel()

	handler := FiberHandlerFunc(func(c fiber.Ctx) error {
		c.Status(fiber.StatusTeapot)
		return c.SendStreamWriter(func(bw *bufio.Writer) {
			bw.WriteString("Hello ")           //nolint:errcheck // not needed
			bw.Flush()                         //nolint:errcheck // not needed
			time.Sleep(200 * time.Millisecond) // Simulate a long operation
			bw.WriteString("World!")           //nolint:errcheck // not needed
		})
	})

	req := &http.Request{
		Method:     http.MethodGet,
		RequestURI: "/test",
		Header:     make(http.Header),
		Body:       http.NoBody,
	}
	rec := &netHTTPResponseWriter{}
	handler.ServeHTTP(rec, req)

	// The status set by the handler and the full streamed body must both arrive.
	require.Equal(t, fiber.StatusTeapot, rec.StatusCode())
	require.Equal(t, "Hello World!", string(rec.body))
}
// Test_FiberHandler_WithInterruptedSendStreamWriter serves a streaming fiber
// handler through a real http.Server. The client uses a 200ms timeout, while the
// handler pauses 500ms between its two writes. The client must still receive the
// status and the first flushed chunk ("Hello ") before its read fails with
// context.DeadlineExceeded.
func Test_FiberHandler_WithInterruptedSendStreamWriter(t *testing.T) {
	t.Parallel()
	// Test streaming functionality to ensure data is sent even during a timeout.
	fiberH := func(c fiber.Ctx) error {
		c.Status(fiber.StatusTeapot)
		return c.SendStreamWriter(func(w *bufio.Writer) {
			w.WriteString("Hello ")            //nolint:errcheck // not needed
			w.Flush()                          //nolint:errcheck // not needed
			time.Sleep(500 * time.Millisecond) // Simulate a long operation
			w.WriteString("World!")            //nolint:errcheck // not needed
		})
	}
	handlerFunc := FiberHandlerFunc(fiberH)
	// Start a mock HTTP server using the handlerFunc
	server := &http.Server{
		Handler:      handlerFunc,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 10 * time.Second,
	}
	listener, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	addr := fmt.Sprintf("http://%s", listener.Addr())
	go func() {
		server.Serve(listener) //nolint:errcheck // not needed
	}()
	defer func() {
		require.NoError(t, server.Close())
	}()
	cc := &http.Client{
		Timeout: 200 * time.Millisecond,
	}
	resp, err := cc.Get(addr) //nolint:noctx // ctx is not needed
	require.NoError(t, err)
	require.NotNil(t, resp)
	// Release the client connection even though reading the body times out.
	defer resp.Body.Close() //nolint:errcheck // not needed
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
	body, readErr := io.ReadAll(resp.Body)
	require.ErrorIs(t, readErr, context.DeadlineExceeded)
	require.Equal(t, "Hello ", string(body))
}
// failingReader is an io.ReadCloser stub. Every Read fails, which lets tests
// exercise the adapter's body-copy error path. Close is a no-op.
type failingReader struct{}

// Read ignores the destination buffer and always reports a simulated error.
func (*failingReader) Read(_ []byte) (int, error) {
	return 0, errors.New("simulated read error")
}

// Close always succeeds.
func (*failingReader) Close() error {
	return nil
}
// Benchmark_FiberHandlerFunc measures FiberHandlerFunc serving POST requests
// whose bodies range from empty to 50MB.
//
// The request body is rewound before every iteration. A body that is drained
// on the first call would make all later iterations see an empty body,
// whatever the sub-benchmark's nominal size. The recorder's body buffer is
// also reset, so output from earlier iterations does not accumulate.
func Benchmark_FiberHandlerFunc(b *testing.B) {
	benchmarks := []struct {
		name        string
		bodyContent []byte
	}{
		{
			name:        "No Content",
			bodyContent: nil, // No body content case
		},
		{
			name:        "100KB",
			bodyContent: make([]byte, 100*1024),
		},
		{
			name:        "500KB",
			bodyContent: make([]byte, 500*1024),
		},
		{
			name:        "1MB",
			bodyContent: make([]byte, 1*1024*1024),
		},
		{
			name:        "5MB",
			bodyContent: make([]byte, 5*1024*1024),
		},
		{
			name:        "10MB",
			bodyContent: make([]byte, 10*1024*1024),
		},
		{
			name:        "25MB",
			bodyContent: make([]byte, 25*1024*1024),
		},
		{
			name:        "50MB",
			bodyContent: make([]byte, 50*1024*1024),
		},
	}
	fiberH := func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	}
	handlerFunc := FiberHandlerFunc(fiberH)
	for _, bm := range benchmarks {
		b.Run(bm.name, func(b *testing.B) {
			w := httptest.NewRecorder()
			// bytes.NewReader(nil) reads as an empty body, which covers "No Content".
			body := bytes.NewReader(bm.bodyContent)
			r := http.Request{
				Method: http.MethodPost,
				Body:   io.NopCloser(body),
			}
			b.ReportAllocs()
			for b.Loop() {
				// Rewind without allocating so every iteration sends the full body.
				body.Reset(bm.bodyContent)
				w.Body.Reset()
				handlerFunc.ServeHTTP(w, &r)
			}
		})
	}
}
// Benchmark_FiberHandlerFunc_Parallel measures FiberHandlerFunc under
// b.RunParallel with request bodies ranging from empty to 50MB.
//
// Each goroutine gets its own bytes.Reader over the shared, read-only
// bodyContent slice. The previous version shared one bytes.Buffer across all
// goroutines, which is a data race and was drained after the first request.
// The reader is rewound before every request so each call sees the full body.
func Benchmark_FiberHandlerFunc_Parallel(b *testing.B) {
	benchmarks := []struct {
		name        string
		bodyContent []byte
	}{
		{
			name:        "No Content",
			bodyContent: nil, // No body content case
		},
		{
			name:        "100KB",
			bodyContent: make([]byte, 100*1024),
		},
		{
			name:        "500KB",
			bodyContent: make([]byte, 500*1024),
		},
		{
			name:        "1MB",
			bodyContent: make([]byte, 1*1024*1024),
		},
		{
			name:        "5MB",
			bodyContent: make([]byte, 5*1024*1024),
		},
		{
			name:        "10MB",
			bodyContent: make([]byte, 10*1024*1024),
		},
		{
			name:        "25MB",
			bodyContent: make([]byte, 25*1024*1024),
		},
		{
			name:        "50MB",
			bodyContent: make([]byte, 50*1024*1024),
		},
	}
	fiberH := func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	}
	handlerFunc := FiberHandlerFunc(fiberH)
	for _, bm := range benchmarks {
		b.Run(bm.name, func(b *testing.B) {
			b.ReportAllocs()
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				w := httptest.NewRecorder()
				// Per-goroutine reader: concurrent readers never share a read offset.
				body := bytes.NewReader(bm.bodyContent)
				r := http.Request{
					Method: http.MethodPost,
					Body:   io.NopCloser(body),
				}
				for pb.Next() {
					body.Reset(bm.bodyContent)
					// Keep the per-goroutine recorder from growing across requests.
					w.Body.Reset()
					handlerFunc(w, &r)
				}
			})
		})
	}
}
// Benchmark_HTTPHandler measures a plain net/http handler wrapped by
// HTTPHandler. One acquired fiber context is reset and reused for a
// GET /test request on every iteration.
func Benchmark_HTTPHandler(b *testing.B) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("ok")) //nolint:errcheck // not needed
	})

	app := fiber.New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	wrapped := HTTPHandler(okHandler)

	b.ReportAllocs()
	b.ResetTimer()

	var lastErr error
	for b.Loop() {
		c.Request().Reset()
		c.Response().Reset()
		c.Request().SetRequestURI("/test")
		c.Request().Header.SetMethod("GET")
		lastErr = wrapped(c)
	}
	// Only the final iteration's error is checked.
	require.NoError(b, lastErr)
}
// Benchmark_HTTPHandlerWithContext measures HTTPHandlerWithContext wrapping a
// plain net/http handler. Before the loop, a context.Context carrying a value
// is attached to the reused fiber context.
func Benchmark_HTTPHandlerWithContext(b *testing.B) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("ok")) //nolint:errcheck // not needed
	})

	app := fiber.New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	b.ReportAllocs()
	b.ResetTimer()

	// A private key type keeps the value from colliding with other packages' keys.
	type ctxKey struct{}
	c.SetContext(context.WithValue(c.Context(), ctxKey{}, "gofiber"))

	wrapped := HTTPHandlerWithContext(okHandler)

	var lastErr error
	for b.Loop() {
		c.Request().Reset()
		c.Response().Reset()
		c.Request().SetRequestURI("/test")
		c.Request().Header.SetMethod("GET")
		lastErr = wrapped(c)
	}
	require.NoError(b, lastErr)
}
// Test_resolveRemoteAddr covers address resolution for the adapter:
//   - TCP addresses with and without a port (a missing port becomes 80),
//   - unix-socket requests, where the local address is returned,
//   - malformed, empty and over-long inputs, which must fail.
//
// Each case's expectError field drives the branch. On success, the resolved
// address is also compared with the value promised by the case name.
func Test_resolveRemoteAddr(t *testing.T) {
	t.Parallel()
	tests := []struct {
		expectedErr   error
		localAddr     any
		name          string
		remoteAddr    string
		errorContains string
		expectedAddr  string // String() of the resolved address on success
		expectError   bool
	}{
		{
			name:         "valid TCP address with port",
			remoteAddr:   "192.168.1.1:8080",
			localAddr:    nil,
			expectError:  false,
			expectedAddr: "192.168.1.1:8080",
		},
		{
			name:         "valid TCP address without port - should add default port 80",
			remoteAddr:   "192.168.1.1",
			localAddr:    nil,
			expectError:  false,
			expectedAddr: "192.168.1.1:80",
		},
		{
			name:         "unix socket - should return local addr",
			remoteAddr:   "irrelevant",
			localAddr:    &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
			expectError:  false,
			expectedAddr: "/tmp/test.sock",
		},
		{
			name:          "invalid address - should fail",
			remoteAddr:    "[invalid:address:format",
			localAddr:     nil,
			expectError:   true,
			errorContains: "failed to resolve TCP address:",
		},
		{
			name:          "invalid address after adding port - should fail",
			remoteAddr:    "[invalid",
			localAddr:     nil,
			expectError:   true,
			errorContains: "failed to resolve TCP address after adding port:",
		},
		{
			name:        "empty address - should fail",
			remoteAddr:  "",
			localAddr:   nil,
			expectError: true,
			expectedErr: ErrRemoteAddrEmpty,
		},
		{
			name:        "too long address - should fail",
			remoteAddr:  strings.Repeat("a", 254),
			localAddr:   nil,
			expectError: true,
			expectedErr: ErrRemoteAddrTooLong,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			addr, err := resolveRemoteAddr(tt.remoteAddr, tt.localAddr)
			if !tt.expectError {
				require.NoError(t, err)
				require.NotNil(t, addr)
				require.Equal(t, tt.expectedAddr, addr.String())
				return
			}
			require.Error(t, err)
			if tt.expectedErr != nil {
				require.ErrorIs(t, err, tt.expectedErr)
			}
			if tt.errorContains != "" {
				require.Contains(t, err.Error(), tt.errorContains)
			}
			require.Nil(t, addr)
		})
	}
}
func Test_isUnixNetwork(t *testing.T) {
t.Parallel()
tests := []struct {
name string
network string
expected bool
}{
{"unix", "unix", true},
{"unixgram", "unixgram", true},
{"unixpacket", "unixpacket", true},
{"tcp", "tcp", false},
{"tcp4", "tcp4", false},
{"tcp6", "tcp6", false},
{"udp", "udp", false},
{"empty", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := isUnixNetwork(tt.network)
require.Equal(t, tt.expected, result)
})
}
}
func Test_FiberHandler_ErrorFallback(t *testing.T) {
t.Parallel()
// Test case where resolveRemoteAddr fails and falls back to nil
fiberH := func(c fiber.Ctx) error {
return c.SendString("success")
}
handlerFunc := FiberHandlerFunc(fiberH)
// Use helper function for cleaner test setup
req := createTestRequest(http.MethodGet, "/test", "[invalid:address:format", nil)
w := executeHandlerTest(t, handlerFunc, req)
// Should still work despite the invalid remote address
require.Equal(t, http.StatusOK, w.StatusCode())
require.Equal(t, "success", string(w.body))
}
func Test_FiberHandler_WithUnixSocket(t *testing.T) {
t.Parallel()
// Test case where request has unix socket context
fiberH := func(c fiber.Ctx) error {
return c.SendString("unix socket success")
}
handlerFunc := FiberHandlerFunc(fiberH)
// Create a context with unix socket local address
unixAddr := &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"}
ctx := context.WithValue(context.Background(), http.LocalAddrContextKey, unixAddr)
r := &http.Request{
Method: http.MethodGet,
RequestURI: "/test",
RemoteAddr: "someremoteaddr", // This will be ignored due to unix socket
Header: make(http.Header),
Body: http.NoBody,
}
r = r.WithContext(ctx)
w := &netHTTPResponseWriter{}
handlerFunc.ServeHTTP(w, r)
require.Equal(t, http.StatusOK, w.StatusCode())
require.Equal(t, "unix socket success", string(w.body))
}
func Test_FiberHandler_BodySizeLimit(t *testing.T) {
t.Parallel()
// Test body size limit enforcement
fiberH := func(c fiber.Ctx) error {
return c.SendString("processed")
}
handlerFunc := FiberHandlerFunc(fiberH)
// Create a large body exceeding limit
largeBody := make([]byte, 15*1024*1024) // 15MB > 10MB limit
req := createTestRequest(http.MethodPost, "/test", "127.0.0.1:8080", bytes.NewReader(largeBody))
req.ContentLength = int64(len(largeBody))
w := executeHandlerTest(t, handlerFunc, req)
// Should return 413 due to size limit
require.Equal(t, http.StatusRequestEntityTooLarge, w.StatusCode())
}
// Test_CopyContextToFiberContext_Safe checks that CopyContextToFiberContext
// tolerates a context struct with both an exported and an unexported field.
// Before this change the field named "exportedField" was itself unexported,
// so the exported case was never exercised.
func Test_CopyContextToFiberContext_Safe(t *testing.T) {
	t.Parallel()
	t.Run("safe handling of unexported fields", func(t *testing.T) {
		t.Parallel()
		type testContext struct {
			ExportedField string
			unexported    string
		}
		var fctx fasthttp.RequestCtx
		ctx := testContext{ExportedField: "exported", unexported: "unexported"}
		require.NotPanics(t, func() {
			CopyContextToFiberContext(&ctx, &fctx)
		})
	})
}
// TestUnixSocketAdaptor serves a fiber app through FiberApp on a real unix
// domain socket, sends a raw HTTP/1.1 request, and checks the parsed response.
// It skips on platforms where unix sockets are unsupported.
func TestUnixSocketAdaptor(t *testing.T) {
	dir := t.TempDir()
	socketPath := filepath.Join(dir, "test.sock")
	defer func() {
		// Closing the listener normally unlinks the socket already, so a
		// missing file is expected and not worth logging.
		if err := os.Remove(socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
			t.Logf("cleanup failed: %v", err)
		}
	}()

	app := fiber.New()
	app.Get("/hello", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})
	handler := FiberApp(app)

	listener, err := net.Listen("unix", socketPath)
	if err != nil {
		// Skip on platforms where the "unix" network is unsupported
		if strings.Contains(err.Error(), "unknown network") ||
			strings.Contains(err.Error(), "address family not supported") {
			t.Skipf("Unix domain sockets not supported on this platform: %v", err)
		}
		t.Fatal(err)
	}
	defer func() {
		// srv.Shutdown below closes the listener first, so net.ErrClosed is expected.
		if closeErr := listener.Close(); closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
			t.Logf("listener close failed: %v", closeErr)
		}
	}()

	// Start the server with timeouts.
	srv := &http.Server{
		Handler:      handler,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 10 * time.Second,
	}
	done := make(chan struct{})
	go func() {
		defer close(done)
		if serveErr := srv.Serve(listener); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
			t.Errorf("http server failed: %v", serveErr)
		}
	}()

	conn, err := net.Dial("unix", socketPath)
	require.NoError(t, err)
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			t.Logf("conn close failed: %v", closeErr)
		}
	}()

	// Set a single deadline covering both the write and the read (2s).
	require.NoError(t, conn.SetDeadline(time.Now().Add(2*time.Second)))

	_, err = conn.Write([]byte("GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n"))
	require.NoError(t, err)

	// Parse the whole response instead of trusting one conn.Read to return
	// every byte. A short read would otherwise make the test flaky.
	resp, err := http.ReadResponse(bufio.NewReader(conn), nil)
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.NoError(t, resp.Body.Close())

	// Clear the deadline to avoid affecting further calls.
	require.NoError(t, conn.SetDeadline(time.Time{}))

	require.Equal(t, "HTTP/1.1", resp.Proto)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, "ok", string(body))

	// Shut the server down explicitly before waiting for the serve goroutine.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	require.NoError(t, srv.Shutdown(ctx))
	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("server shutdown timed out")
	}
}
================================================
FILE: middleware/basicauth/basicauth.go
================================================
package basicauth
import (
"encoding/base64"
"errors"
"strings"
"unicode"
"unicode/utf8"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
"golang.org/x/text/unicode/norm"
)
// contextKey is an unexported type for this package's context keys. Keys of
// this type cannot collide with context keys defined in other packages.
type contextKey int

const (
	// usernameKey is the context key under which the authenticated username is stored.
	usernameKey contextKey = 0

	// basicScheme is the authentication scheme name expected at the start of
	// the Authorization header.
	basicScheme = "Basic"
)
// New creates a new HTTP Basic authentication middleware handler.
//
// The handler reads the Authorization header and responds with:
//   - cfg.Unauthorized when the header is missing or blank, the scheme is not
//     "Basic" (matched case-insensitively), or cfg.Authorizer rejects the
//     credentials;
//   - 431 Request Header Fields Too Large when the raw header is longer than
//     cfg.HeaderLimit;
//   - cfg.BadRequest when the header contains disallowed characters, the
//     scheme is not followed by exactly one space, the credentials contain
//     whitespace, the payload is not valid (padded or unpadded) base64 or
//     UTF-8, the decoded value has no ':', or either part contains a control
//     character.
//
// Decoded credentials are NFC-normalized before they reach the Authorizer.
// On success the username is stored with fiber.StoreInContext (see
// UsernameFromContext) and the next handler runs.
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)

	// Return new handler
	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Get authorization header and ensure it matches the Basic scheme
		rawAuth := c.Get(fiber.HeaderAuthorization)
		if rawAuth == "" {
			return cfg.Unauthorized(c)
		}
		if len(rawAuth) > cfg.HeaderLimit {
			return c.SendStatus(fiber.StatusRequestHeaderFieldsTooLarge)
		}
		if containsInvalidHeaderChars(rawAuth) {
			return cfg.BadRequest(c)
		}
		auth := utils.TrimSpace(rawAuth)
		if auth == "" {
			return cfg.Unauthorized(c)
		}
		if len(auth) < len(basicScheme) || !utils.EqualFold(auth[:len(basicScheme)], basicScheme) {
			return cfg.Unauthorized(c)
		}
		// Exactly one space must separate the scheme from the credentials.
		rest := auth[len(basicScheme):]
		if len(rest) < 2 || rest[0] != ' ' || rest[1] == ' ' {
			return cfg.BadRequest(c)
		}
		rest = rest[1:]
		if strings.IndexFunc(rest, unicode.IsSpace) != -1 {
			return cfg.BadRequest(c)
		}

		// Decode the header contents. Padded base64 is tried first; on a
		// CorruptInputError the unpadded alphabet is tried as well.
		raw, err := base64.StdEncoding.DecodeString(rest)
		if err != nil {
			// Declared per request: errors.As writes into its target, so a
			// variable shared by every invocation of this closure would be
			// a data race between concurrent requests.
			var corruptErr base64.CorruptInputError
			if errors.As(err, &corruptErr) {
				raw, err = base64.RawStdEncoding.DecodeString(rest)
			}
			if err != nil {
				return cfg.BadRequest(c)
			}
		}
		if !utf8.Valid(raw) {
			return cfg.BadRequest(c)
		}
		if !norm.NFC.IsNormal(raw) {
			raw = norm.NFC.Bytes(raw)
		}

		// Get the credentials
		var creds string
		if c.App().Config().Immutable {
			creds = string(raw)
		} else {
			creds = utils.UnsafeString(raw)
		}

		// Check if the credentials are in the correct form
		// which is "username:password".
		username, password, found := strings.Cut(creds, ":")
		if !found {
			return cfg.BadRequest(c)
		}
		if containsCTL(username) || containsCTL(password) {
			return cfg.BadRequest(c)
		}

		if cfg.Authorizer(username, password, c) {
			fiber.StoreInContext(c, usernameKey, username)
			return c.Next()
		}

		// Authentication failed
		return cfg.Unauthorized(c)
	}
}
// containsCTL reports whether s contains a rune that unicode.IsControl
// classifies as a control character (C0 controls, DEL and C1 controls).
func containsCTL(s string) bool {
	return strings.ContainsFunc(s, unicode.IsControl)
}
// containsInvalidHeaderChars reports whether s contains a character that is
// not accepted in the Authorization header: an ASCII control character other
// than horizontal tab, DEL (0x7F), or any non-ASCII rune. Invalid UTF-8 decodes
// to U+FFFD and is therefore rejected too.
func containsInvalidHeaderChars(s string) bool {
	disallowed := func(r rune) bool {
		if r == '\t' {
			return false
		}
		return r < 0x20 || r >= 0x7F
	}
	return strings.ContainsFunc(s, disallowed)
}
// UsernameFromContext returns the username stored by the basicauth middleware.
// ctx may be a fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx or
// context.Context. If no username is stored, the empty string is returned.
func UsernameFromContext(ctx any) string {
	username, found := fiber.ValueFromContext[string](ctx, usernameKey)
	if !found {
		return ""
	}
	return username
}
================================================
FILE: middleware/basicauth/basicauth_test.go
================================================
package basicauth
import (
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"golang.org/x/crypto/bcrypt"
)
// sha256Hash returns p's SHA-256 digest in the "{SHA256}<base64>" format used
// for the Users map in these tests.
func sha256Hash(p string) string {
	digest := sha256.Sum256([]byte(p))
	encoded := base64.StdEncoding.EncodeToString(digest[:])
	return "{SHA256}" + encoded
}
// sha512Hash returns p's SHA-512 digest in the "{SHA512}<base64>" format used
// for the Users map in these tests.
func sha512Hash(p string) string {
	digest := sha512.Sum512([]byte(p))
	encoded := base64.StdEncoding.EncodeToString(digest[:])
	return "{SHA512}" + encoded
}
// go test -run Test_BasicAuth_Next
//
// Test_BasicAuth_Next checks that the middleware is bypassed when Next returns
// true. With no routes registered, the unauthenticated request falls through
// to 404 instead of being challenged.
func Test_BasicAuth_Next(t *testing.T) {
	t.Parallel()

	alwaysSkip := func(_ fiber.Ctx) bool { return true }
	app := fiber.New()
	app.Use(New(Config{Next: alwaysSkip}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// Test_Middleware_BasicAuth authenticates against a Users map holding one
// SHA-256 hash and one bcrypt hash. It checks that valid users reach the route
// (which echoes UsernameFromContext) and that an unknown user gets 401.
// Each case's url field is now actually used for the request.
func Test_Middleware_BasicAuth(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	hashedJohn := sha256Hash("doe")
	hashedAdmin, err := bcrypt.GenerateFromPassword([]byte("123456"), bcrypt.MinCost)
	require.NoError(t, err)
	app.Use(New(Config{
		Users: map[string]string{
			"john":  hashedJohn,
			"admin": string(hashedAdmin),
		},
	}))
	app.Get("/testauth", func(c fiber.Ctx) error {
		username := UsernameFromContext(c)
		return c.SendString(username)
	})
	tests := []struct {
		url        string
		username   string
		password   string
		statusCode int
	}{
		{
			url:        "/testauth",
			statusCode: fiber.StatusOK,
			username:   "john",
			password:   "doe",
		},
		{
			url:        "/testauth",
			statusCode: fiber.StatusOK,
			username:   "admin",
			password:   "123456",
		},
		{
			url:        "/testauth",
			statusCode: fiber.StatusUnauthorized,
			username:   "ee",
			password:   "123456",
		},
	}
	for _, tt := range tests {
		// Base64 encode credentials for http auth header
		creds := base64.StdEncoding.EncodeToString(fmt.Appendf(nil, "%s:%s", tt.username, tt.password))
		req := httptest.NewRequest(fiber.MethodGet, tt.url, http.NoBody)
		req.Header.Add(fiber.HeaderAuthorization, "Basic "+creds)
		resp, err := app.Test(req)
		require.NoError(t, err)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, tt.statusCode, resp.StatusCode, "user %q", tt.username)
		if tt.statusCode == fiber.StatusOK {
			require.Equal(t, tt.username, string(body))
		}
	}
}
// Test_BasicAuth_UsernameFromContext_Types checks that UsernameFromContext
// finds the stored username through every supported context type. The types
// are fiber.Ctx, fiber.CustomCtx, *fasthttp.RequestCtx, and context.Context
// (with PassLocalsToContext enabled).
//
// The handler only records results. Assertions run on the test goroutine,
// because app.Test executes handlers on another goroutine, where require's
// FailNow must not be called.
func Test_BasicAuth_UsernameFromContext_Types(t *testing.T) {
	t.Parallel()
	app := fiber.New(fiber.Config{PassLocalsToContext: true})
	app.Use(New(Config{
		Users: map[string]string{
			"john": sha256Hash("doe"),
		},
	}))

	var (
		fromCtx, fromCustom, fromRequestCtx, fromContext string
		isCustomCtx                                      bool
	)
	app.Get("/", func(c fiber.Ctx) error {
		fromCtx = UsernameFromContext(c)
		if customCtx, ok := c.(fiber.CustomCtx); ok {
			isCustomCtx = true
			fromCustom = UsernameFromContext(customCtx)
		}
		fromRequestCtx = UsernameFromContext(c.RequestCtx())
		fromContext = UsernameFromContext(c.Context())
		return c.SendStatus(fiber.StatusOK)
	})

	creds := base64.StdEncoding.EncodeToString([]byte("john:doe"))
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add(fiber.HeaderAuthorization, "Basic "+creds)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	require.Equal(t, "john", fromCtx)
	require.True(t, isCustomCtx)
	require.Equal(t, "john", fromCustom)
	require.Equal(t, "john", fromRequestCtx)
	require.Equal(t, "john", fromContext)
}
// Test_BasicAuth_AuthorizerCtx checks that a custom Authorizer receives the
// decoded username, password and the current request's fiber context.
//
// The Authorizer only records its arguments. Assertions run on the test
// goroutine, because app.Test executes handlers on another goroutine, where
// require's FailNow must not be called.
func Test_BasicAuth_AuthorizerCtx(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	var (
		called           bool
		gotUser, gotPass string
		gotPath          string
	)
	app.Use(New(Config{
		Authorizer: func(user, pass string, c fiber.Ctx) bool {
			called = true
			gotUser, gotPass = user, pass
			// Clone: c.Path() may alias request buffers that are recycled
			// once the handler returns.
			gotPath = strings.Clone(c.Path())
			return true
		},
	}))
	app.Get("/ctx", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) })

	creds := base64.StdEncoding.EncodeToString([]byte("john:doe"))
	req := httptest.NewRequest(fiber.MethodGet, "/ctx", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	require.True(t, called)
	require.Equal(t, "john", gotUser)
	require.Equal(t, "doe", gotPass)
	require.Equal(t, "/ctx", gotPath)
}
// Test_BasicAuth_WWWAuthenticateHeader checks that a request without
// credentials gets 401 with the default realm and UTF-8 charset challenge.
func Test_BasicAuth_WWWAuthenticateHeader(t *testing.T) {
	t.Parallel()
	const wantChallenge = `Basic realm="Restricted", charset="UTF-8"`

	app := fiber.New()
	app.Use(New(Config{Users: map[string]string{"john": sha256Hash("doe")}}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
	require.Equal(t, wantChallenge, resp.Header.Get(fiber.HeaderWWWAuthenticate))
}
// Test_BasicAuth_WWWAuthenticateHeader_UTF8 checks that a lower-case "utf-8"
// Charset setting is advertised as "UTF-8" in the challenge header.
func Test_BasicAuth_WWWAuthenticateHeader_UTF8(t *testing.T) {
	t.Parallel()
	const wantChallenge = `Basic realm="Restricted", charset="UTF-8"`

	app := fiber.New()
	app.Use(New(Config{
		Users:   map[string]string{"john": sha256Hash("doe")},
		Charset: "utf-8",
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
	require.Equal(t, wantChallenge, resp.Header.Get(fiber.HeaderWWWAuthenticate))
}
// Test_BasicAuth_InvalidHeader checks that a Basic header whose payload is
// not decodable base64 is rejected with 400.
func Test_BasicAuth_InvalidHeader(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Users: map[string]string{"john": sha256Hash("doe")}}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Basic notbase64")

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
}
// Test_BasicAuth_MissingScheme checks that a non-Basic scheme (Bearer) is
// treated as missing credentials: 401 plus the Basic challenge header.
func Test_BasicAuth_MissingScheme(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Users: map[string]string{"john": sha256Hash("doe")}}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Bearer token")

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
	require.Equal(t, `Basic realm="Restricted", charset="UTF-8"`, resp.Header.Get(fiber.HeaderWWWAuthenticate))
}
// Test_BasicAuth_MissingColon checks that decoded credentials without the
// "username:password" separator are rejected with 400.
func Test_BasicAuth_MissingColon(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Users: map[string]string{"john": sha256Hash("doe")}}))

	noSeparator := base64.StdEncoding.EncodeToString([]byte("john"))
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Basic "+noSeparator)

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
}
// Test_BasicAuth_EmptyAuthorization checks that an empty or whitespace-only
// Authorization header is answered with 401, as if no credentials were sent.
func Test_BasicAuth_EmptyAuthorization(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Users: map[string]string{"john": sha256Hash("doe")}}))

	for _, headerValue := range []string{"", " "} {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, headerValue)

		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
	}
}
// Test_BasicAuth_HeaderWhitespace checks whitespace handling in the
// Authorization header. Leading and trailing spaces or tabs are trimmed. The
// scheme must be followed by exactly one ASCII space. The credentials
// themselves may not contain whitespace.
func Test_BasicAuth_HeaderWhitespace(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	hashedJohn := sha256Hash("doe")
	app.Use(New(Config{Users: map[string]string{"john": hashedJohn}}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) })
	creds := base64.StdEncoding.EncodeToString([]byte("john:doe"))
	cases := []struct {
		header string
		status int
	}{
		{"Basic " + creds, fiber.StatusTeapot},
		{" Basic " + creds, fiber.StatusTeapot},
		// More than one space between the scheme and the credentials is rejected.
		{"Basic  " + creds, fiber.StatusBadRequest},
		{"Basic   " + creds, fiber.StatusBadRequest},
		{"Basic\t" + creds, fiber.StatusBadRequest},
		{"Basic \t" + creds, fiber.StatusBadRequest},
		{"Basic\u00A0" + creds, fiber.StatusBadRequest},
		{"Basic\u3000" + creds, fiber.StatusBadRequest},
		{"\tBasic " + creds + "\t", fiber.StatusTeapot},
		{"Basic " + creds[:4] + " " + creds[4:], fiber.StatusBadRequest},
	}
	for _, tt := range cases {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, tt.header)
		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, tt.status, resp.StatusCode, "header %q", tt.header)
	}
}
// Test_BasicAuth_ControlChars checks that credentials containing control
// characters are rejected with 400 before the Authorizer is called, and
// without a WWW-Authenticate challenge.
//
// The lone bytes 0x85 and 0x9F are invalid UTF-8 and are rejected by the
// UTF-8 check. The same code points encoded as valid UTF-8 (U+0085, U+009F)
// pass that check and exercise the C1 control-character check itself.
func Test_BasicAuth_ControlChars(t *testing.T) {
	t.Parallel()
	called := false
	app := fiber.New()
	app.Use(New(Config{
		Authorizer: func(_, _ string, _ fiber.Ctx) bool {
			called = true
			return true
		},
	}))
	creds := []string{
		// ASCII C0 control and DEL.
		base64.StdEncoding.EncodeToString([]byte("john:\x01doe")),
		base64.StdEncoding.EncodeToString([]byte("jo\x7Fhn:doe")),
		// Lone C1 bytes: invalid UTF-8.
		base64.StdEncoding.EncodeToString([]byte{'j', 'o', 'h', 'n', ':', 0x85, 'd', 'o', 'e'}),
		base64.StdEncoding.EncodeToString([]byte{'j', 'o', 'h', 'n', ':', 0x9F, 'd', 'o', 'e'}),
		// C1 controls as valid UTF-8.
		base64.StdEncoding.EncodeToString([]byte("john:\u0085doe")),
		base64.StdEncoding.EncodeToString([]byte("jo\u009Fhn:doe")),
	}
	for _, cred := range creds {
		called = false
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, "Basic "+cred)
		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, fiber.StatusBadRequest, resp.StatusCode, "credentials %q", cred)
		require.Empty(t, resp.Header.Get(fiber.HeaderWWWAuthenticate))
		require.False(t, called)
	}
}
// Test_BasicAuth_UnpaddedBase64 checks that credentials encoded without the
// trailing '=' padding are still accepted.
func Test_BasicAuth_UnpaddedBase64(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Users: map[string]string{"john": sha256Hash("doe")}}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) })

	padded := base64.StdEncoding.EncodeToString([]byte("john:doe"))
	unpadded := strings.TrimRight(padded, "=")

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Basic "+unpadded)

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
}
// Test_BasicAuth_NonASCIIHeader drives the handler directly with fasthttp, so
// a raw non-ASCII byte (0x80) reaches the Authorization header unmodified. It
// checks that such a header is rejected with 400.
func Test_BasicAuth_NonASCIIHeader(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Users: map[string]string{"john": sha256Hash("doe")}}))
	serve := app.Handler()

	encoded := base64.StdEncoding.EncodeToString([]byte("john:doe"))
	headerValue := []byte("Basic \x80" + encoded)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.SetRequestURI("/")
	reqCtx.Request.Header.SetMethod(fiber.MethodGet)
	reqCtx.Request.Header.SetBytesKV([]byte(fiber.HeaderAuthorization), headerValue)

	serve(reqCtx)
	require.Equal(t, fiber.StatusBadRequest, reqCtx.Response.StatusCode())
}
// Test_BasicAuth_InvalidUTF8 checks that decoded credentials that are not
// valid UTF-8 are rejected with 400 without calling the Authorizer.
func Test_BasicAuth_InvalidUTF8(t *testing.T) {
	t.Parallel()

	authorizerCalled := false
	app := fiber.New()
	app.Use(New(Config{
		Charset: "UTF-8",
		Authorizer: func(_, _ string, _ fiber.Ctx) bool {
			authorizerCalled = true
			return true
		},
	}))

	// 0xff never appears in valid UTF-8.
	badUTF8 := []byte{'j', 'o', 'h', 'n', ':', 0xff, 'd', 'o', 'e'}
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Basic "+base64.StdEncoding.EncodeToString(badUTF8))

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
	require.False(t, authorizerCalled)
}
// Test_BasicAuth_UTF8Normalization checks that decomposed UTF-8 credentials
// (e + U+0301 combining acute accent) reach the Authorizer in NFC form
// (U+00E9).
//
// The Authorizer only records its arguments. Assertions run on the test
// goroutine, because app.Test executes handlers on another goroutine, where
// require's FailNow must not be called.
func Test_BasicAuth_UTF8Normalization(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	decomposed := "e\u0301" // e + combining acute accent

	var (
		called           bool
		gotUser, gotPass string
	)
	app.Use(New(Config{
		Charset: "UTF-8",
		Authorizer: func(u, p string, _ fiber.Ctx) bool {
			called = true
			gotUser, gotPass = u, p
			return true
		},
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) })

	creds := base64.StdEncoding.EncodeToString([]byte(decomposed + ":doe"))
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)

	require.True(t, called)
	require.Equal(t, "\u00e9", gotUser) // precomposed é
	require.Equal(t, "doe", gotPass)
}
// Test_BasicAuth_HeaderControlCharEdges checks that CR or LF at either end of
// the Authorization header is rejected with 400 rather than trimmed away. The
// raw bytes are injected through fasthttp directly.
func Test_BasicAuth_HeaderControlCharEdges(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Users: map[string]string{"john": sha256Hash("doe")}}))
	serve := app.Handler()

	valid := "Basic " + base64.StdEncoding.EncodeToString([]byte("john:doe"))
	for _, value := range []string{
		"\r" + valid,
		"\n" + valid,
		valid + "\r",
		valid + "\n",
	} {
		reqCtx := &fasthttp.RequestCtx{}
		reqCtx.Request.SetRequestURI("/")
		reqCtx.Request.Header.SetMethod(fiber.MethodGet)
		reqCtx.Request.Header.SetBytesKV([]byte(fiber.HeaderAuthorization), []byte(value))

		serve(reqCtx)
		require.Equal(t, fiber.StatusBadRequest, reqCtx.Response.StatusCode())
	}
}
// Test_BasicAuth_Charset checks charset validation at construction time.
// ISO-8859-1 makes New panic. "utf-8", "UTF-8" and an unset charset are
// accepted.
func Test_BasicAuth_Charset(t *testing.T) {
	t.Parallel()

	require.Panics(t, func() { New(Config{Charset: "ISO-8859-1"}) })

	for _, accepted := range []Config{{Charset: "utf-8"}, {Charset: "UTF-8"}, {}} {
		require.NotPanics(t, func() { New(accepted) })
	}
}
// Test_BasicAuth_HeaderLimit checks HeaderLimit handling. A header longer than
// the limit yields 431. A header within the limit is authenticated normally.
func Test_BasicAuth_HeaderLimit(t *testing.T) {
	t.Parallel()
	creds := base64.StdEncoding.EncodeToString([]byte("john:doe"))
	hashedJohn := sha256Hash("doe")

	// send issues a GET / with john's Basic credentials and returns the status code.
	send := func(t *testing.T, app *fiber.App) int {
		t.Helper()
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds)
		resp, err := app.Test(req)
		require.NoError(t, err)
		return resp.StatusCode
	}

	t.Run("too large", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Users: map[string]string{"john": hashedJohn}, HeaderLimit: 10}))
		require.Equal(t, fiber.StatusRequestHeaderFieldsTooLarge, send(t, app))
	})

	t.Run("allowed", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Users: map[string]string{"john": hashedJohn}, HeaderLimit: 100}))
		app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) })
		require.Equal(t, fiber.StatusTeapot, send(t, app))
	})
}
// go test -v -run=^$ -bench=Benchmark_Middleware_BasicAuth -benchmem -count=4
//
// Benchmark_Middleware_BasicAuth measures a successful SHA-256 authentication
// using a lower-case "basic" scheme, exercising the case-insensitive match.
func Benchmark_Middleware_BasicAuth(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{
		Users: map[string]string{
			"john": sha256Hash("doe"),
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	serve := app.Handler()
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod(fiber.MethodGet)
	reqCtx.Request.SetRequestURI("/")
	reqCtx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe

	b.ReportAllocs()
	for b.Loop() {
		serve(reqCtx)
	}

	require.Equal(b, fiber.StatusTeapot, reqCtx.Response.Header.StatusCode())
}
// go test -v -run=^$ -bench=Benchmark_Middleware_BasicAuth_Upper -benchmem -count=4
//
// Benchmark_Middleware_BasicAuth_Upper measures a successful SHA-256
// authentication using the canonical "Basic" scheme spelling.
func Benchmark_Middleware_BasicAuth_Upper(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{
		Users: map[string]string{
			"john": sha256Hash("doe"),
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	serve := app.Handler()
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod(fiber.MethodGet)
	reqCtx.Request.SetRequestURI("/")
	reqCtx.Request.Header.Set(fiber.HeaderAuthorization, "Basic am9objpkb2U=") // john:doe

	b.ReportAllocs()
	for b.Loop() {
		serve(reqCtx)
	}

	require.Equal(b, fiber.StatusTeapot, reqCtx.Response.Header.StatusCode())
}
// Test_BasicAuth_Immutable checks that authentication still succeeds when the
// app runs with Immutable enabled.
func Test_BasicAuth_Immutable(t *testing.T) {
	t.Parallel()

	app := fiber.New(fiber.Config{Immutable: true})
	app.Use(New(Config{Users: map[string]string{"john": sha256Hash("doe")}}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	encoded := base64.StdEncoding.EncodeToString([]byte("john:doe"))
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Basic "+encoded)

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
}
// Test_parseHashedPassword checks every supported password-hash encoding.
// The encodings are bcrypt, {SHA512}, {SHA256}, bare hex SHA-256 and bare
// base64 SHA-256. Each verifier must accept the right password and reject a
// wrong one.
func Test_parseHashedPassword(t *testing.T) {
	t.Parallel()

	const secret = "secret"
	digest := sha256.Sum256([]byte(secret))
	bcryptHash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.MinCost)
	require.NoError(t, err)

	for _, tc := range []struct {
		name   string
		hashed string
	}{
		{name: "bcrypt", hashed: string(bcryptHash)},
		{name: "sha512", hashed: sha512Hash(secret)},
		{name: "sha256", hashed: sha256Hash(secret)},
		{name: "sha256-hex", hashed: hex.EncodeToString(digest[:])},
		{name: "sha256-b64", hashed: base64.StdEncoding.EncodeToString(digest[:])},
	} {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			verify, err := parseHashedPassword(tc.hashed)
			require.NoError(t, err)
			require.True(t, verify(secret))
			require.False(t, verify("wrong"))
		})
	}
}
// Test_BasicAuth_HashVariants checks that the middleware authenticates the
// correct password for every supported stored-hash format. Each variant runs
// as a named subtest, so a failure identifies the failing format.
func Test_BasicAuth_HashVariants(t *testing.T) {
	t.Parallel()
	pass := "doe"
	bcryptHash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.MinCost)
	require.NoError(t, err)
	cases := []struct {
		name   string
		hashed string
	}{
		{"bcrypt", string(bcryptHash)},
		{"sha512", sha512Hash(pass)},
		{"sha256", sha256Hash(pass)},
		{"sha256-hex", func() string { h := sha256.Sum256([]byte(pass)); return hex.EncodeToString(h[:]) }()},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			app.Use(New(Config{Users: map[string]string{"john": tt.hashed}}))
			app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) })
			creds := base64.StdEncoding.EncodeToString([]byte("john:" + pass))
			req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
			req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds)
			resp, err := app.Test(req)
			require.NoError(t, err)
			require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
		})
	}
}
// Test_BasicAuth_HashVariants_Invalid checks that a wrong password is rejected
// with 401 for every supported stored-hash format. Each format runs as its own
// named subtest so a failure identifies the format that broke.
func Test_BasicAuth_HashVariants_Invalid(t *testing.T) {
	t.Parallel()
	pass := "doe"
	wrong := "wrong"
	bcryptHash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.MinCost)
	require.NoError(t, err)
	cases := []struct {
		name   string
		hashed string
	}{
		{"bcrypt", string(bcryptHash)},
		{"sha512", sha512Hash(pass)},
		{"sha256", sha256Hash(pass)},
		{"sha256-hex", func() string { h := sha256.Sum256([]byte(pass)); return hex.EncodeToString(h[:]) }()},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			app.Use(New(Config{Users: map[string]string{"john": tt.hashed}}))
			app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) })
			creds := base64.StdEncoding.EncodeToString([]byte("john:" + wrong))
			req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
			req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds)
			resp, err := app.Test(req)
			require.NoError(t, err)
			require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode)
		})
	}
}
================================================
FILE: middleware/basicauth/config.go
================================================
package basicauth
import (
"crypto/sha256"
"crypto/sha512"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
"golang.org/x/crypto/bcrypt"
)
// ErrInvalidSHA256PasswordLength is returned when a configured SHA-256 password
// digest decodes successfully but is not sha256.Size bytes long.
var ErrInvalidSHA256PasswordLength = errors.New("decode SHA256 password: invalid length")
// Config defines the config for the basicauth middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// Users maps each allowed username to its stored password hash (not a
	// plain-text password). Each value is parsed by parseHashedPassword and may be:
	//   - a bcrypt hash (starting with "$2"),
	//   - "{SHA512}" followed by a standard-base64 SHA-512 digest,
	//   - "{SHA256}" followed by a standard-base64 SHA-256 digest, or
	//   - a bare SHA-256 digest, hex- or standard-base64-encoded.
	// Users is only consulted when Authorizer is nil; in that case configDefault
	// panics if a value cannot be parsed.
	//
	// Required. Default: map[string]string{}
	Users map[string]string

	// Authorizer defines a function you can pass
	// to check the credentials however you want.
	// It will be called with a username, password and
	// the current fiber context and is expected to return
	// true or false to indicate that the credentials were
	// approved or not. If left nil, configDefault builds one that
	// verifies the password against the matching Users entry.
	//
	// Optional. Default: nil.
	Authorizer func(string, string, fiber.Ctx) bool

	// Unauthorized defines the response for unauthorized requests.
	// If left nil, configDefault installs a handler that responds with
	// 401 Unauthorized and sets WWW-Authenticate (with the quoted Realm and
	// Charset), Cache-Control: no-store and Vary: Authorization.
	//
	// Optional. Default: nil
	Unauthorized fiber.Handler

	// BadRequest defines the response for malformed Authorization headers.
	// If left nil, configDefault installs a handler that responds with
	// 400 Bad Request and no WWW-Authenticate header.
	//
	// Optional. Default: nil
	BadRequest fiber.Handler

	// Realm is a string to define realm attribute of BasicAuth.
	// The realm identifies the system to authenticate against
	// and can be used by clients to save credentials.
	//
	// Optional. Default: "Restricted".
	Realm string

	// Charset defines the value for the charset parameter in the
	// WWW-Authenticate header. According to RFC 7617 clients can use
	// this value to interpret credentials correctly. Only the value
	// "UTF-8" (matched case-insensitively) is allowed; any other value
	// will panic.
	//
	// Optional. Default: "UTF-8".
	Charset string

	// HeaderLimit specifies the maximum allowed length of the
	// Authorization header. Requests exceeding this limit will
	// be rejected. Zero or negative values fall back to the default.
	//
	// Optional. Default: 8192.
	HeaderLimit int
}
// ConfigDefault is the default config.
//
// The handler fields (Authorizer, Unauthorized, BadRequest) are deliberately
// nil here; configDefault installs the built-in implementations for a
// caller-supplied Config that leaves them unset.
var ConfigDefault = Config{
	Next:         nil,
	Users:        map[string]string{},
	Authorizer:   nil,
	Unauthorized: nil,
	BadRequest:   nil,
	Realm:        "Restricted",
	Charset:      "UTF-8",
	HeaderLimit:  8192,
}
// configDefault returns the effective middleware configuration.
//
// The first Config in config (if any) is the base; when none is supplied,
// ConfigDefault is the base. Either way every unset field is filled in below.
// This matters because ConfigDefault leaves Authorizer, Unauthorized and
// BadRequest nil, so returning it unmodified would hand out nil handlers.
//
// It panics if Charset is anything other than "UTF-8" (case-insensitive), or if
// Authorizer is nil and a Users value is rejected by parseHashedPassword.
func configDefault(config ...Config) Config {
	cfg := ConfigDefault
	if len(config) > 0 {
		cfg = config[0]
	}

	// Set default values
	if cfg.Next == nil {
		cfg.Next = ConfigDefault.Next
	}
	if cfg.Users == nil {
		cfg.Users = ConfigDefault.Users
	}
	if cfg.Realm == "" {
		cfg.Realm = ConfigDefault.Realm
	}
	switch {
	case cfg.Charset == "":
		cfg.Charset = ConfigDefault.Charset
	case utils.EqualFold(cfg.Charset, "UTF-8"):
		// Normalize the spelling used in the WWW-Authenticate header.
		cfg.Charset = "UTF-8"
	default:
		panic("basicauth: charset must be UTF-8")
	}
	if cfg.HeaderLimit <= 0 {
		cfg.HeaderLimit = ConfigDefault.HeaderLimit
	}
	if cfg.Authorizer == nil {
		// Parse every stored hash once, up front, so misconfigured users fail
		// at startup rather than on each request.
		verifiers := make(map[string]func(string) bool, len(cfg.Users))
		for u, hpw := range cfg.Users {
			v, err := parseHashedPassword(hpw)
			if err != nil {
				panic(err)
			}
			verifiers[u] = v
		}
		cfg.Authorizer = func(user, pass string, _ fiber.Ctx) bool {
			verify, ok := verifiers[user]
			return ok && verify(pass)
		}
	}
	if cfg.Unauthorized == nil {
		cfg.Unauthorized = func(c fiber.Ctx) error {
			header := "Basic realm=" + strconv.Quote(cfg.Realm)
			if cfg.Charset != "" {
				header += ", charset=" + strconv.Quote(cfg.Charset)
			}
			c.Set(fiber.HeaderWWWAuthenticate, header)
			c.Set(fiber.HeaderCacheControl, "no-store")
			c.Set(fiber.HeaderVary, fiber.HeaderAuthorization)
			return c.SendStatus(fiber.StatusUnauthorized)
		}
	}
	if cfg.BadRequest == nil {
		cfg.BadRequest = func(c fiber.Ctx) error {
			return c.SendStatus(fiber.StatusBadRequest)
		}
	}
	return cfg
}
// ErrInvalidSHA512PasswordLength is returned when a "{SHA512}" password hash
// decodes successfully but is not sha512.Size bytes long.
var ErrInvalidSHA512PasswordLength = errors.New("decode SHA512 password: invalid length")

// parseHashedPassword converts a stored password hash into a function that
// reports whether a candidate plain-text password matches it.
//
// Accepted formats:
//   - bcrypt: any value starting with "$2", which must be a well-formed bcrypt hash;
//   - "{SHA512}" followed by the standard-base64 encoding of a SHA-512 digest;
//   - "{SHA256}" followed by the standard-base64 encoding of a SHA-256 digest;
//   - a bare SHA-256 digest, hex-encoded (tried first) or standard-base64-encoded.
//
// It returns an error when the value cannot be decoded or the decoded digest has
// the wrong length, because such a hash could never match any password.
// SHA digests are compared with subtle.ConstantTimeCompare.
func parseHashedPassword(h string) (func(string) bool, error) {
	switch {
	case strings.HasPrefix(h, "$2"):
		hash := []byte(h)
		// bcrypt.Cost parses the hash the same way CompareHashAndPassword does,
		// so a malformed hash is reported here instead of failing every login.
		if _, err := bcrypt.Cost(hash); err != nil {
			return nil, fmt.Errorf("decode bcrypt password: %w", err)
		}
		return func(p string) bool {
			return bcrypt.CompareHashAndPassword(hash, []byte(p)) == nil
		}, nil
	case strings.HasPrefix(h, "{SHA512}"):
		b, err := base64.StdEncoding.DecodeString(h[len("{SHA512}"):])
		if err != nil {
			return nil, fmt.Errorf("decode SHA512 password: %w", err)
		}
		if len(b) != sha512.Size {
			return nil, ErrInvalidSHA512PasswordLength
		}
		return func(p string) bool {
			sum := sha512.Sum512([]byte(p))
			return subtle.ConstantTimeCompare(sum[:], b) == 1
		}, nil
	case strings.HasPrefix(h, "{SHA256}"):
		b, err := base64.StdEncoding.DecodeString(h[len("{SHA256}"):])
		if err != nil {
			return nil, fmt.Errorf("decode SHA256 password: %w", err)
		}
		if len(b) != sha256.Size {
			return nil, ErrInvalidSHA256PasswordLength
		}
		return func(p string) bool {
			sum := sha256.Sum256([]byte(p))
			return subtle.ConstantTimeCompare(sum[:], b) == 1
		}, nil
	default:
		// Bare digest: accept hex when it yields exactly 32 bytes, otherwise base64.
		b, err := hex.DecodeString(h)
		if err != nil || len(b) != sha256.Size {
			if b, err = base64.StdEncoding.DecodeString(h); err != nil {
				return nil, fmt.Errorf("decode SHA256 password: %w", err)
			}
			if len(b) != sha256.Size {
				return nil, ErrInvalidSHA256PasswordLength
			}
		}
		return func(p string) bool {
			sum := sha256.Sum256([]byte(p))
			return subtle.ConstantTimeCompare(sum[:], b) == 1
		}, nil
	}
}
================================================
FILE: middleware/cache/cache.go
================================================
// Special thanks to @codemicro for moving this to fiber core
// Original middleware: github.com/codemicro/fiber-cache
package cache
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/utils/v2"
utilsstrings "github.com/gofiber/utils/v2/strings"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v3"
)
const (
	// timestampUpdatePeriod is how often the shared "current time" used for
	// expiration checks is refreshed. A longer period makes expiry coarser; a
	// shorter one wakes the refresher goroutine more often.
	timestampUpdatePeriod = 300 * time.Millisecond

	// hexLen is the length of a hex-encoded SHA-256 digest; it sizes the
	// pooled hex encoding buffers.
	hexLen = 2 * sha256.Size
)
// Values written to the configured cache status header.
const (
	// cacheHit: the response was served from the cache.
	cacheHit = "hit"
	// cacheMiss: no usable entry existed and the fresh response was stored.
	cacheMiss = "miss"
	// cacheUnreachable: the cache was bypassed, or the entry/response was not usable.
	cacheUnreachable = "unreachable"
)
// expirationSource records which input determined a new entry's lifetime.
type expirationSource uint8

const (
	expirationSourceConfig    expirationSource = iota // cfg.Expiration (fallback)
	expirationSourceMaxAge                            // response Cache-Control max-age
	expirationSourceSMaxAge                           // response Cache-Control s-maxage (shared cache)
	expirationSourceExpires                           // response Expires header
	expirationSourceGenerator                         // cfg.ExpirationGenerator
)
// Cache-Control / Pragma directive names matched by the middleware.
const (
	noStore          = "no-store"
	noCache          = "no-cache"
	privateDirective = "private"
)
// requestCacheDirectives holds the Cache-Control directives parsed from a
// request. The *Set flags distinguish "directive absent" from a zero argument.
// Numeric arguments are compared against Unix-second timestamps, i.e. seconds.
type requestCacheDirectives struct {
	maxAge       uint64 // max-age argument (valid when maxAgeSet)
	maxStale     uint64 // max-stale argument (valid when maxStaleSet)
	minFresh     uint64 // min-fresh argument (valid when minFreshSet)
	maxAgeSet    bool
	maxStaleSet  bool
	maxStaleAny  bool // any staleness is accepted; presumably set for a bare max-stale (see parseRequestCacheControl)
	minFreshSet  bool
	noStore      bool // no-store: the cache is bypassed entirely
	noCache      bool // no-cache (or Pragma: no-cache): a stored entry is not served
	onlyIfCached bool // only-if-cached: respond 504 when no entry can be served
}
// ignoreHeaders lists response header names that are not copied into a cache
// entry's generic header list when StoreResponseHeaders is enabled. The set
// covers hop-by-hop headers, headers recomputed on replay, and headers the
// cache manager already stores in dedicated fields.
//
// Lookups are exact string matches against the names produced by
// Header.All(). fasthttp normalizes header names by default (e.g. "ETag" ->
// "Etag", "TE" -> "Te") unless name normalizing is disabled, so both spellings
// are listed where they differ.
var ignoreHeaders = map[string]struct{}{
	"Age":                 {},
	"Cache-Control":       {}, // already stored explicitly by the cache manager
	"Connection":          {},
	"Content-Encoding":    {}, // already stored explicitly by the cache manager
	"Content-Type":        {}, // already stored explicitly by the cache manager
	"Date":                {},
	"ETag":                {}, // already stored explicitly by the cache manager
	"Etag":                {}, // fasthttp-normalized spelling of ETag
	"Expires":             {}, // already stored explicitly by the cache manager
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"TE":                  {},
	"Te":                  {}, // fasthttp-normalized spelling of TE
	"Trailer":             {}, // actual header name (RFC 9110 §6.6.2)
	"Trailers":            {}, // spelling used in the RFC 2616 §13.5.1 hop-by-hop list
	"Transfer-Encoding":   {},
	"Upgrade":             {},
}
// cacheableStatusCodes is the set of response status codes the middleware is
// willing to store; responses with any other status are passed through
// uncached. The list appears to mirror RFC 9110's "heuristically cacheable"
// status codes — keep them in sync.
var cacheableStatusCodes = func() map[int]struct{} {
	codes := []int{
		// 2xx
		fiber.StatusOK,
		fiber.StatusNonAuthoritativeInformation,
		fiber.StatusNoContent,
		fiber.StatusPartialContent,
		// 3xx
		fiber.StatusMultipleChoices,
		fiber.StatusMovedPermanently,
		fiber.StatusPermanentRedirect,
		// 4xx
		fiber.StatusNotFound,
		fiber.StatusMethodNotAllowed,
		fiber.StatusGone,
		fiber.StatusRequestURITooLong,
		// 5xx
		fiber.StatusNotImplemented,
	}
	set := make(map[int]struct{}, len(codes))
	for _, code := range codes {
		set[code] = struct{}{}
	}
	return set
}()
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
type evictionCandidate struct {
key string
size uint
exp uint64
heapIdx int
}
redactKeys := !cfg.DisableValueRedaction
maskKey := func(key string) string {
if redactKeys {
return redactedKey
}
return key
}
// Nothing to cache
if int(cfg.Expiration.Seconds()) < 0 {
return func(c fiber.Ctx) error {
return c.Next()
}
}
var (
// Cache settings
mux = &sync.RWMutex{}
timestamp = safeUnixSeconds(time.Now())
)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage, redactKeys)
// Create indexed heap for tracking expirations ( see heap.go )
heap := &indexedHeap{}
// count stored bytes (sizes of response bodies)
var storedBytes uint
// Pool for hex encoding buffers
hexBufPool := &sync.Pool{
New: func() any {
buf := make([]byte, hexLen)
return &buf
},
}
hashAuthorization := makeHashAuthFunc(hexBufPool)
buildVaryKey := makeBuildVaryKeyFunc(hexBufPool)
// Update timestamp in the configured interval
go func() {
ticker := time.NewTicker(timestampUpdatePeriod)
defer ticker.Stop()
for range ticker.C {
atomic.StoreUint64(×tamp, safeUnixSeconds(time.Now()))
}
}()
// Delete key from both manager and storage
deleteKey := func(ctx context.Context, dkey string) error {
if err := manager.del(ctx, dkey); err != nil {
return err
}
// External storage saves body data with different key
if cfg.Storage != nil {
if err := manager.del(ctx, dkey+"_body"); err != nil {
return err
}
}
return nil
}
removeHeapEntry := func(entryKey string, heapIdx int) {
if cfg.MaxBytes == 0 {
return
}
if heapIdx < 0 || heapIdx >= len(heap.indices) {
return
}
indexedIdx := heap.indices[heapIdx]
if indexedIdx < 0 || indexedIdx >= len(heap.entries) {
return
}
entry := heap.entries[indexedIdx]
if entry.idx != heapIdx || entry.key != entryKey {
return
}
_, size := heap.remove(heapIdx)
storedBytes -= size
}
refreshHeapIndex := func(ctx context.Context, candidate evictionCandidate) error {
entry, err := manager.get(ctx, candidate.key)
if err != nil {
if errors.Is(err, errCacheMiss) {
return nil
}
return fmt.Errorf("cache: failed to reload key %q after eviction failure: %w", maskKey(candidate.key), err)
}
entry.heapidx = candidate.heapIdx
remainingTTL := max(time.Until(secondsToTime(entry.exp)), 0)
if err := manager.set(ctx, candidate.key, entry, remainingTTL); err != nil {
return fmt.Errorf("cache: failed to restore heap index for key %q: %w", maskKey(candidate.key), err)
}
return nil
}
// Return new handler
return func(c fiber.Ctx) error {
hasAuthorization := len(c.Request().Header.Peek(fiber.HeaderAuthorization)) > 0
reqCacheControl := c.Request().Header.Peek(fiber.HeaderCacheControl)
reqDirectives := parseRequestCacheControl(reqCacheControl)
if !reqDirectives.noCache {
reqPragma := utils.UnsafeString(c.Request().Header.Peek(fiber.HeaderPragma))
if hasDirective(reqPragma, noCache) {
reqDirectives.noCache = true
}
}
// Refrain from caching
if reqDirectives.noStore {
return c.Next()
}
requestMethod := c.Method()
// Only cache selected methods
if !slices.Contains(cfg.Methods, requestMethod) {
c.Set(cfg.CacheHeader, cacheUnreachable)
return c.Next()
}
// Get key from request
baseKey := cfg.KeyGenerator(c) + "_" + requestMethod
manifestKey := baseKey + "|vary"
if hasAuthorization {
authHash := hashAuthorization(c.Request().Header.Peek(fiber.HeaderAuthorization))
baseKey += "_auth_" + authHash
manifestKey = baseKey + "|vary"
}
key := baseKey
reqCtx := c.Context()
varyNames, hasVaryManifest, err := loadVaryManifest(reqCtx, manager, manifestKey)
if err != nil {
return err
}
if len(varyNames) > 0 {
key += buildVaryKey(varyNames, &c.Request().Header)
}
// Get entry from pool
e, err := manager.get(reqCtx, key)
if err != nil && !errors.Is(err, errCacheMiss) {
return err
}
entryAge := uint64(0)
revalidate := false
oldHeapIdx := -1 // Track old heap index for replacement during revalidation
handleMinFresh := func(now uint64) {
if e == nil || !reqDirectives.minFreshSet {
return
}
remainingFreshness := remainingFreshness(e, now)
if remainingFreshness < reqDirectives.minFresh {
revalidate = true
oldHeapIdx = e.heapidx
if cfg.Storage != nil {
manager.release(e)
}
e = nil
}
}
// Lock entry
mux.Lock()
locked := true
unlock := func() {
if locked {
mux.Unlock()
locked = false
}
}
relock := func() {
if !locked {
mux.Lock()
locked = true
}
}
// Get timestamp
ts := atomic.LoadUint64(×tamp)
// Cache Entry found
if e != nil {
entryAge = cachedResponseAge(e, ts)
if reqDirectives.maxAgeSet && (reqDirectives.maxAge == 0 || entryAge > reqDirectives.maxAge) {
revalidate = true
oldHeapIdx = e.heapidx
if cfg.Storage != nil {
manager.release(e)
}
e = nil
}
handleMinFresh(ts)
}
if e != nil && e.ttl == 0 && e.forceRevalidate {
revalidate = true
oldHeapIdx = e.heapidx
if cfg.Storage != nil {
manager.release(e)
}
e = nil
}
if e != nil && e.ttl == 0 && e.exp != 0 && ts >= e.exp {
unlock()
if err := deleteKey(reqCtx, key); err != nil {
if cfg.Storage != nil {
manager.release(e)
}
return fmt.Errorf("cache: failed to delete expired key %q: %w", maskKey(key), err)
}
relock()
removeHeapEntry(key, e.heapidx)
if cfg.Storage != nil {
manager.release(e)
}
e = nil
unlock()
c.Set(cfg.CacheHeader, cacheUnreachable)
goto continueRequest
}
if e != nil {
entryHasPrivate := e != nil && e.private
if !entryHasPrivate && cfg.StoreResponseHeaders && len(e.headers) > 0 {
if cc, ok := lookupCachedHeader(e.headers, fiber.HeaderCacheControl); ok && hasDirective(utils.UnsafeString(cc), privateDirective) {
entryHasPrivate = true
}
}
requestNoCache := reqDirectives.noCache
// Invalidate cache if requested
if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) {
e.exp = ts - 1
}
entryHasExpiration := e != nil && e.exp != 0
entryExpired := entryHasExpiration && ts >= e.exp
staleness := uint64(0)
if entryExpired {
staleness = ts - e.exp
}
allowStale := entryExpired && (reqDirectives.maxStaleAny || (reqDirectives.maxStaleSet && staleness <= reqDirectives.maxStale))
if entryExpired && e.revalidate {
revalidate = true
oldHeapIdx = e.heapidx
if cfg.Storage != nil {
manager.release(e)
}
e = nil
}
handleMinFresh(ts)
if revalidate {
unlock()
c.Set(cfg.CacheHeader, cacheUnreachable)
if reqDirectives.onlyIfCached {
return c.SendStatus(fiber.StatusGatewayTimeout)
}
goto continueRequest
}
servedStale := false
switch {
case entryExpired && !allowStale:
unlock()
if err := deleteKey(reqCtx, key); err != nil {
if e != nil {
manager.release(e)
}
return fmt.Errorf("cache: failed to delete expired key %q: %w", maskKey(key), err)
}
relock()
idx := e.heapidx
manager.release(e)
removeHeapEntry(key, idx)
e = nil
case entryHasPrivate:
unlock()
if err := deleteKey(reqCtx, key); err != nil {
if e != nil {
manager.release(e)
}
return fmt.Errorf("cache: failed to delete private response for key %q: %w", maskKey(key), err)
}
relock()
removeHeapEntry(key, e.heapidx)
if cfg.Storage != nil && e != nil {
manager.release(e)
}
e = nil
unlock()
c.Set(cfg.CacheHeader, cacheUnreachable)
if reqDirectives.onlyIfCached {
return c.SendStatus(fiber.StatusGatewayTimeout)
}
return c.Next()
case entryHasExpiration && !requestNoCache:
servedStale = entryExpired
if hasAuthorization && !e.shareable {
if cfg.Storage != nil {
manager.release(e)
}
unlock()
c.Set(cfg.CacheHeader, cacheUnreachable)
return c.Next()
}
// Separate body value to avoid msgp serialization
// We can store raw bytes with Storage 👍
if cfg.Storage != nil {
unlock()
rawBody, err := manager.getRaw(reqCtx, key+"_body")
if err != nil {
manager.release(e)
return cacheBodyFetchError(maskKey, key, err)
}
e.body = rawBody
} else {
unlock()
}
// Set response headers from cache
c.Response().SetBodyRaw(e.body)
c.Response().SetStatusCode(e.status)
c.Response().Header.SetContentTypeBytes(e.ctype)
if len(e.cencoding) > 0 {
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
}
if len(e.cacheControl) > 0 {
c.Response().Header.SetBytesV(fiber.HeaderCacheControl, e.cacheControl)
}
if len(e.expires) > 0 {
c.Response().Header.SetBytesV(fiber.HeaderExpires, e.expires)
}
if len(e.etag) > 0 {
c.Response().Header.SetBytesV(fiber.HeaderETag, e.etag)
}
clampedDate := clampDateSeconds(e.date, ts)
dateValue := fasthttp.AppendHTTPDate(nil, secondsToTime(clampedDate))
c.Response().Header.SetBytesV(fiber.HeaderDate, dateValue)
for i := range e.headers {
h := e.headers[i]
c.Response().Header.SetBytesKV(h.key, h.value)
}
// Set Cache-Control header if not disabled and not already set
if !cfg.DisableCacheControl && len(c.Response().Header.Peek(fiber.HeaderCacheControl)) == 0 {
remaining := uint64(0)
if e.exp > ts {
remaining = e.exp - ts
}
maxAge := utils.FormatUint(remaining)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}
const maxDeltaSeconds = uint64(math.MaxInt32)
ageSeconds := min(entryAge, maxDeltaSeconds)
// RFC-compliant Age header (RFC 9111)
age := utils.FormatUint(ageSeconds)
c.Response().Header.Set(fiber.HeaderAge, age)
appendWarningHeaders(&c.Response().Header, servedStale, isHeuristicFreshness(e, &cfg, entryAge))
c.Set(cfg.CacheHeader, cacheHit)
// release item allocated from storage
if cfg.Storage != nil {
manager.release(e)
}
// Return response
return nil
default:
// no cached response to serve
}
}
if e == nil && revalidate {
unlock()
c.Set(cfg.CacheHeader, cacheUnreachable)
if reqDirectives.onlyIfCached {
return c.SendStatus(fiber.StatusGatewayTimeout)
}
goto continueRequest
}
if e == nil && reqDirectives.onlyIfCached {
unlock()
c.Set(cfg.CacheHeader, cacheUnreachable)
return c.SendStatus(fiber.StatusGatewayTimeout)
}
// make sure we're not blocking concurrent requests - do unlock
unlock()
continueRequest:
// Continue stack, return err to Fiber if exist
if err := c.Next(); err != nil {
return err
}
cacheControlBytes := c.Response().Header.Peek(fiber.HeaderCacheControl)
respCacheControl := parseResponseCacheControl(cacheControlBytes)
varyHeader := utils.UnsafeString(c.Response().Header.Peek(fiber.HeaderVary))
hasPrivate := respCacheControl.hasPrivate
hasNoCache := respCacheControl.hasNoCache
varyNames, varyHasStar := parseVary(varyHeader)
// Respect server cache-control: no-store
if respCacheControl.hasNoStore {
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
if hasPrivate || hasNoCache || varyHasStar {
if e != nil {
if err := deleteKey(reqCtx, key); err != nil {
if cfg.Storage != nil {
manager.release(e)
}
return fmt.Errorf("cache: failed to delete cached response for key %q: %w", maskKey(key), err)
}
mux.Lock()
removeHeapEntry(key, e.heapidx)
if cfg.Storage != nil {
manager.release(e)
}
e = nil
mux.Unlock()
}
if hasVaryManifest {
if err := manager.del(reqCtx, manifestKey); err != nil {
return fmt.Errorf("cache: failed to delete stale vary manifest %q: %w", maskKey(manifestKey), err)
}
}
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
shouldStoreVaryManifest := len(varyNames) > 0
if len(varyNames) > 0 {
if key == baseKey {
key += buildVaryKey(varyNames, &c.Request().Header)
}
} else if hasVaryManifest {
if err := manager.del(reqCtx, manifestKey); err != nil {
return fmt.Errorf("cache: failed to delete stale vary manifest %q: %w", maskKey(manifestKey), err)
}
}
isSharedCacheAllowed := allowsSharedCacheDirectives(respCacheControl)
if hasAuthorization && !isSharedCacheAllowed {
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
sharedCacheMode := !hasAuthorization || isSharedCacheAllowed
// Don't cache response if status code is not cacheable
if _, ok := cacheableStatusCodes[c.Response().StatusCode()]; !ok {
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
// Don't cache response if Next returns true
if cfg.Next != nil && cfg.Next(c) {
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
// Don't try to cache if body won't fit into cache
bodySize := uint(len(c.Response().Body()))
if cfg.MaxBytes > 0 && bodySize > cfg.MaxBytes {
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
// Eviction loop: atomically reserve space for new entry and evict old entries.
// Strategy:
// 1. Under lock: reserve space by pre-incrementing storedBytes, then collect entries to evict
// 2. Outside lock: perform I/O deletions
// 3. On deletion failure: restore storedBytes and return error
// 4. Track reservation with a flag; unreserve on early return via defer
var spaceReserved bool
defer func() {
// If we reserved space but the entry was not successfully added to heap, unreserve it
if cfg.MaxBytes > 0 && spaceReserved {
mux.Lock()
storedBytes -= bodySize
mux.Unlock()
}
}()
if cfg.MaxBytes > 0 {
mux.Lock()
// Reserve space for the new entry first
storedBytes += bodySize
spaceReserved = true
// Now evict entries until we're under the limit
var keysToRemove []string
var sizesToRemove []uint
var candidates []evictionCandidate
for storedBytes > cfg.MaxBytes {
if heap.Len() == 0 {
// Can't evict more, unreserve space and fail
storedBytes -= bodySize
// Set spaceReserved to false so the deferred cleanup does not unreserve again
spaceReserved = false
mux.Unlock()
return errors.New("cache: insufficient space and no entries to evict")
}
next := heap.entries[0]
keyToRemove, size := heap.removeFirst()
keysToRemove = append(keysToRemove, keyToRemove)
sizesToRemove = append(sizesToRemove, size)
candidates = append(candidates, evictionCandidate{
key: keyToRemove,
size: size,
exp: next.exp,
})
storedBytes -= size
}
mux.Unlock()
// Perform deletions outside the lock
if len(keysToRemove) > 0 {
for i, keyToRemove := range keysToRemove {
delErr := deleteKey(reqCtx, keyToRemove)
if delErr == nil {
continue
}
// Deletion failed: restore storedBytes for failed deletions
mux.Lock()
// Restore sizes of entries we failed to delete
for j := i; j < len(sizesToRemove); j++ {
storedBytes += sizesToRemove[j]
}
// Unreserve space for the new entry
storedBytes -= bodySize
spaceReserved = false
// Re-add entries to the heap to keep expiration tracking consistent
var restored []evictionCandidate
for j := i; j < len(candidates); j++ {
candidate := candidates[j]
candidate.heapIdx = heap.put(candidate.key, candidate.exp, candidate.size)
restored = append(restored, candidate)
}
mux.Unlock()
var restoreErr error
for _, candidate := range restored {
if err := refreshHeapIndex(reqCtx, candidate); err != nil {
restoreErr = errors.Join(restoreErr, err)
}
}
if restoreErr != nil {
return errors.Join(fmt.Errorf("cache: failed to delete key %q while evicting: %w", maskKey(keyToRemove), delErr), restoreErr)
}
return fmt.Errorf("cache: failed to delete key %q while evicting: %w", maskKey(keyToRemove), delErr)
}
}
}
e = manager.acquire()
// Cache response
e.body = utils.CopyBytes(c.Response().Body())
e.status = c.Response().StatusCode()
e.ctype = utils.CopyBytes(c.Response().Header.ContentType())
e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding))
e.private = false
e.cacheControl = utils.CopyBytes(cacheControlBytes)
e.expires = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderExpires))
e.etag = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderETag))
e.date = 0
ageVal := uint64(0)
if b := c.Response().Header.Peek(fiber.HeaderAge); len(b) > 0 {
if v, err := fasthttp.ParseUint(b); err == nil {
if v >= 0 {
ageVal = uint64(v)
}
}
} else {
c.Response().Header.Set(fiber.HeaderAge, "0")
}
e.age = ageVal
e.shareable = isSharedCacheAllowed
now := time.Now().UTC()
nowUnix := safeUnixSeconds(now)
dateHeader := c.Response().Header.Peek(fiber.HeaderDate)
parsedDate, _ := parseHTTPDate(dateHeader)
e.date = clampDateSeconds(parsedDate, nowUnix)
dateBytes := fasthttp.AppendHTTPDate(nil, secondsToTime(e.date))
c.Response().Header.SetBytesV(fiber.HeaderDate, dateBytes)
// Store all response headers
// (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1)
if cfg.StoreResponseHeaders {
allHeaders := c.Response().Header.All()
e.headers = e.headers[:0]
for key, value := range allHeaders {
keyStr := string(key)
if _, ok := ignoreHeaders[keyStr]; ok {
continue
}
e.headers = append(e.headers, cachedHeader{
key: utils.CopyBytes(utils.UnsafeBytes(keyStr)),
value: utils.CopyBytes(value),
})
}
}
expirationSource := expirationSourceConfig
expiresParseError := false
mustRevalidate := respCacheControl.mustRevalidate || respCacheControl.proxyRevalidate
// default cache expiration
expiration := cfg.Expiration
if sharedCacheMode && respCacheControl.sMaxAgeSet {
expiration = secondsToDuration(respCacheControl.sMaxAge)
expirationSource = expirationSourceSMaxAge
}
if expirationSource == expirationSourceConfig {
if respCacheControl.maxAgeSet {
expiration = secondsToDuration(respCacheControl.maxAge)
expirationSource = expirationSourceMaxAge
} else if expiresBytes := c.Response().Header.Peek(fiber.HeaderExpires); len(expiresBytes) > 0 {
expiresAt, err := fasthttp.ParseHTTPDate(expiresBytes)
if err != nil {
expiration = time.Nanosecond
expiresParseError = true
} else {
expiration = time.Until(expiresAt)
}
expirationSource = expirationSourceExpires
}
}
// Calculate expiration by response header or other setting
if cfg.ExpirationGenerator != nil {
expiration = cfg.ExpirationGenerator(c, &cfg)
expirationSource = expirationSourceGenerator
}
e.forceRevalidate = expiresParseError
e.revalidate = mustRevalidate
storageExpiration := expiration
if expiresParseError || storageExpiration < cfg.Expiration {
storageExpiration = cfg.Expiration
}
if expiration <= 0 && !expiresParseError {
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
ts = atomic.LoadUint64(×tamp)
responseTS := max(ts, nowUnix)
maxAgeSeconds := uint64(time.Duration(math.MaxInt64) / time.Second)
var ageDuration time.Duration
apparentAge := e.age
if e.date > 0 && responseTS > e.date {
dateAge := responseTS - e.date
if dateAge > apparentAge {
apparentAge = dateAge
}
}
if expirationSource != expirationSourceExpires {
if apparentAge > maxAgeSeconds {
ageDuration = expiration + time.Second
} else {
ageDuration = time.Duration(apparentAge) * time.Second
}
}
remainingExpiration := expiration - ageDuration
if remainingExpiration <= 0 {
if expirationSource != expirationSourceExpires {
c.Set(cfg.CacheHeader, cacheUnreachable)
return nil
}
remainingExpiration = 0
}
if shouldStoreVaryManifest {
if err := storeVaryManifest(reqCtx, manager, manifestKey, varyNames, storageExpiration); err != nil {
return err
}
}
e.exp = responseTS + uint64(remainingExpiration.Seconds())
e.ttl = uint64(expiration.Seconds())
if expiresParseError {
e.exp = ts + 1
}
// Store entry in heap (space already reserved in eviction phase)
var heapIdx int
if cfg.MaxBytes > 0 {
mux.Lock()
heapIdx = heap.put(key, e.exp, bodySize)
e.heapidx = heapIdx
// Note: storedBytes was incremented during reservation, and evictions
// have already been accounted for, so no additional increment is needed
spaceReserved = false // Clear flag to prevent defer from unreserving
mux.Unlock()
}
cleanupOnStoreError := func(ctx context.Context, releaseEntry, rawStored bool) error {
var cleanupErr error
if cfg.MaxBytes > 0 {
mux.Lock()
_, size := heap.remove(heapIdx)
storedBytes -= size
mux.Unlock()
}
if releaseEntry {
manager.release(e)
}
if rawStored {
rawKey := key + "_body"
if err := manager.del(ctx, rawKey); err != nil {
cleanupErr = errors.Join(cleanupErr, fmt.Errorf("cache: failed to delete raw key %q after store error: %w", maskKey(rawKey), err))
}
}
return cleanupErr
}
// For external Storage we store raw body separated
if cfg.Storage != nil {
if err := manager.setRaw(reqCtx, key+"_body", e.body, storageExpiration); err != nil {
if cleanupErr := cleanupOnStoreError(reqCtx, true, false); cleanupErr != nil {
err = errors.Join(err, cleanupErr)
}
return err
}
// avoid body msgp encoding
e.body = nil
if err := manager.set(reqCtx, key, e, storageExpiration); err != nil {
if cleanupErr := cleanupOnStoreError(reqCtx, false, true); cleanupErr != nil {
err = errors.Join(err, cleanupErr)
}
return err
}
} else {
// Store entry in memory
if err := manager.set(reqCtx, key, e, storageExpiration); err != nil {
if cleanupErr := cleanupOnStoreError(reqCtx, true, false); cleanupErr != nil {
err = errors.Join(err, cleanupErr)
}
return err
}
}
// If revalidating, remove old heap entry now that replacement is successfully stored
if cfg.MaxBytes > 0 && revalidate && oldHeapIdx >= 0 {
mux.Lock()
removeHeapEntry(key, oldHeapIdx)
mux.Unlock()
}
c.Set(cfg.CacheHeader, cacheMiss)
// Finish response
return nil
}
}
// hasDirective reports whether the comma-separated header value cc (for example
// a Cache-Control or Pragma value) contains directive as a standalone token.
// The comparison is ASCII case-insensitive.
//
// A match must start at the beginning of cc, or right after a comma or optional
// whitespace (SP / HTAB). It must end at the end of cc, at a comma, at optional
// whitespace, or at '=' introducing an argument such as `private="Set-Cookie"`.
// List elements may be surrounded by optional whitespace (RFC 9110 list
// syntax), so a value like "no-cache , private" is recognized.
// An empty directive never matches.
func hasDirective(cc, directive string) bool {
	dirLen := len(directive)
	if dirLen == 0 {
		return false
	}
	for i := 0; i+dirLen <= len(cc); i++ {
		if !utils.EqualFold(cc[i:i+dirLen], directive) {
			continue
		}
		if i > 0 {
			if prev := cc[i-1]; prev != ' ' && prev != '\t' && prev != ',' {
				continue // part of a longer token, e.g. "x-no-cache"
			}
		}
		end := i + dirLen
		if end == len(cc) {
			return true
		}
		switch cc[end] {
		case ',', ' ', '\t', '=':
			return true
		}
	}
	return false
}
// cacheBodyFetchError adapts an error from loading a cached body. A cache miss
// is wrapped with the (masked) key for context; any other error is returned
// unchanged.
func cacheBodyFetchError(mask func(string) string, key string, err error) error {
	if !errors.Is(err, errCacheMiss) {
		return err
	}
	return fmt.Errorf("cache: no cached body for key %q: %w", mask(key), err)
}
// parseUintDirective parses a non-negative decimal directive argument, such as
// the "60" in max-age=60. The boolean is false when val is empty or is not a
// valid number.
func parseUintDirective(val []byte) (uint64, bool) {
	if len(val) > 0 {
		if n, err := fasthttp.ParseUint(val); err == nil && n >= 0 {
			return uint64(n), true
		}
	}
	return 0, false
}
// parseCacheControlDirectives splits a Cache-Control style header value into its
// comma-separated directives. It calls fn once per directive, in order, with the
// directive name and its argument.
//
// Optional whitespace (SP and HTAB) is trimmed around elements, names, '=' and
// arguments. value is nil when the directive has no '='; it is empty when '='
// is followed by nothing. Quoted-string arguments are unquoted via
// unquoteCacheDirective.
//
// A comma inside a well-formed quoted-string does not split the element (e.g.
// `private="Set-Cookie, X-Foo"`). If a quote is never closed, the element ends
// at the first comma instead, as a plain comma split would. key and value may
// alias cc.
func parseCacheControlDirectives(cc []byte, fn func(key, value []byte)) {
	for i := 0; i < len(cc); {
		// Skip separators and optional whitespace before the next element.
		for i < len(cc) && (cc[i] == ',' || isCacheControlOWS(cc[i])) {
			i++
		}
		if i >= len(cc) {
			break
		}
		start := i
		i = cacheControlElementEnd(cc, start)

		partEnd := i
		for partEnd > start && isCacheControlOWS(cc[partEnd-1]) {
			partEnd--
		}

		keyEnd := start
		for keyEnd < partEnd && cc[keyEnd] != '=' {
			keyEnd++
		}
		// Trim whitespace between the name and '='.
		keyEndTrimmed := keyEnd
		for keyEndTrimmed > start && isCacheControlOWS(cc[keyEndTrimmed-1]) {
			keyEndTrimmed--
		}
		key := cc[start:keyEndTrimmed]

		var value []byte
		if keyEnd < partEnd {
			// cc[keyEnd] is '=': everything after it (trimmed) is the argument.
			valueStart := keyEnd + 1
			for valueStart < partEnd && isCacheControlOWS(cc[valueStart]) {
				valueStart++
			}
			value = cc[valueStart:partEnd]
			// Handle quoted-string values per RFC 9111 Section 5.2
			if len(value) >= 2 && value[0] == '"' && value[len(value)-1] == '"' {
				value = unquoteCacheDirective(value)
			}
		}
		fn(key, value)
		i++ // step past the terminating comma (or beyond the end of cc)
	}
}

// isCacheControlOWS reports whether c is optional whitespace (SP or HTAB).
func isCacheControlOWS(c byte) bool {
	return c == ' ' || c == '\t'
}

// cacheControlElementEnd returns the index of the comma that ends the list
// element beginning at start, or len(cc) if the element runs to the end.
// Commas inside a quoted-string (with backslash escapes) are skipped; if a
// quote is never closed, it falls back to the first comma after start.
func cacheControlElementEnd(cc []byte, start int) int {
	inQuotes := false
	for i := start; i < len(cc); i++ {
		switch c := cc[i]; {
		case inQuotes && c == '\\':
			i++ // the following octet is escaped
		case c == '"':
			inQuotes = !inQuotes
		case c == ',' && !inQuotes:
			return i
		}
	}
	if inQuotes {
		for i := start; i < len(cc); i++ {
			if cc[i] == ',' {
				return i
			}
		}
	}
	return len(cc)
}
// unquoteCacheDirective strips the surrounding quote bytes from a quoted-string
// directive value and resolves backslash escapes (`\x` becomes `x`), following
// RFC 9110 Section 5.6.4 as referenced by RFC 9111 Section 5.2.
//
// Inputs shorter than two bytes are returned unchanged. The first and last bytes
// are dropped without being checked. A trailing lone backslash is kept
// literally. When nothing needs unescaping, the result aliases the input.
func unquoteCacheDirective(quoted []byte) []byte {
	if len(quoted) < 2 {
		return quoted
	}
	inner := quoted[1 : len(quoted)-1]

	// Find the first backslash that actually escapes a following byte.
	firstEscape := -1
	for idx := 0; idx+1 < len(inner); idx++ {
		if inner[idx] == '\\' {
			firstEscape = idx
			break
		}
	}
	if firstEscape < 0 {
		return inner
	}

	// Copy the escape-free prefix, then collapse each `\x` pair into `x`.
	unescaped := make([]byte, 0, len(inner))
	unescaped = append(unescaped, inner[:firstEscape]...)
	for idx := firstEscape; idx < len(inner); idx++ {
		ch := inner[idx]
		if ch == '\\' && idx+1 < len(inner) {
			idx++
			ch = inner[idx]
		}
		unescaped = append(unescaped, ch)
	}
	return unescaped
}
// responseCacheControl holds the response Cache-Control directives that the cache
// middleware inspects, as filled in by parseResponseCacheControl. The zero value
// means "no recognized directives".
type responseCacheControl struct {
	maxAge          uint64 // max-age in seconds; meaningful only when maxAgeSet is true
	sMaxAge         uint64 // s-maxage in seconds; meaningful only when sMaxAgeSet is true
	maxAgeSet       bool   // max-age present with a value parseUintDirective accepted
	sMaxAgeSet      bool   // s-maxage present with a value parseUintDirective accepted
	hasNoCache      bool   // no-cache directive present
	hasNoStore      bool   // no-store directive present
	hasPrivate      bool   // private directive present
	hasPublic       bool   // public directive present
	mustRevalidate  bool   // must-revalidate directive present
	proxyRevalidate bool   // proxy-revalidate directive present
}
// parseResponseCacheControl walks a response Cache-Control header value and
// collects the directives the cache middleware acts on. Directive names compare
// case-insensitively and unknown directives are skipped. max-age and s-maxage are
// recorded, and flagged as set, only when parseUintDirective accepts their value.
// A later valid occurrence of a directive overwrites an earlier one.
func parseResponseCacheControl(cc []byte) responseCacheControl {
	var out responseCacheControl
	parseCacheControlDirectives(cc, func(key, value []byte) {
		name := utils.UnsafeString(key)
		switch {
		case utils.EqualFold(name, noStore):
			out.hasNoStore = true
		case utils.EqualFold(name, noCache):
			out.hasNoCache = true
		case utils.EqualFold(name, privateDirective):
			out.hasPrivate = true
		case utils.EqualFold(name, "public"):
			out.hasPublic = true
		case utils.EqualFold(name, "max-age"):
			if secs, ok := parseUintDirective(value); ok {
				out.maxAge, out.maxAgeSet = secs, true
			}
		case utils.EqualFold(name, "s-maxage"):
			if secs, ok := parseUintDirective(value); ok {
				out.sMaxAge, out.sMaxAgeSet = secs, true
			}
		case utils.EqualFold(name, "must-revalidate"):
			out.mustRevalidate = true
		case utils.EqualFold(name, "proxy-revalidate"):
			out.proxyRevalidate = true
		default:
			// directives this cache does not act on are ignored
		}
	})
	return out
}
// parseMaxAge reports the max-age directive of a Cache-Control header value as a
// duration, saturating at the largest representable duration. The boolean is false
// when max-age is absent or its value cannot be parsed.
func parseMaxAge(cc string) (time.Duration, bool) {
	if parsed := parseResponseCacheControl(utils.UnsafeBytes(cc)); parsed.maxAgeSet {
		return secondsToDuration(parsed.maxAge), true
	}
	return 0, false
}
// parseRequestCacheControl walks a request Cache-Control header value and collects
// the directives the cache honors: no-store, no-cache, only-if-cached, max-age,
// max-stale and min-fresh. Directive names compare case-insensitively and unknown
// directives are skipped.
//
// A max-stale without a value sets maxStaleAny. A max-stale whose value
// parseUintDirective rejects still sets maxStaleSet but leaves maxStale unchanged.
// max-age and min-fresh are only flagged as set when their value parses.
func parseRequestCacheControl(cc []byte) requestCacheDirectives {
	var d requestCacheDirectives
	parseCacheControlDirectives(cc, func(key, value []byte) {
		name := utils.UnsafeString(key)
		switch {
		case utils.EqualFold(name, noStore):
			d.noStore = true
		case utils.EqualFold(name, noCache):
			d.noCache = true
		case utils.EqualFold(name, "only-if-cached"):
			d.onlyIfCached = true
		case utils.EqualFold(name, "max-age"):
			if secs, ok := parseUintDirective(value); ok {
				d.maxAge, d.maxAgeSet = secs, true
			}
		case utils.EqualFold(name, "max-stale"):
			d.maxStaleSet = true
			d.maxStaleAny = len(value) == 0
			if d.maxStaleAny {
				break // bare max-stale: a stale response of any age is acceptable
			}
			if secs, ok := parseUintDirective(value); ok {
				d.maxStale = secs
			}
		case utils.EqualFold(name, "min-fresh"):
			if secs, ok := parseUintDirective(value); ok {
				d.minFresh, d.minFreshSet = secs, true
			}
		default:
			// directives this cache does not act on are ignored
		}
	})
	return d
}
// parseRequestCacheControlString is the string-input form of
// parseRequestCacheControl. cc is converted with utils.UnsafeBytes (presumably a
// zero-copy view), so the parser must treat the bytes as read-only.
func parseRequestCacheControlString(cc string) requestCacheDirectives {
	raw := utils.UnsafeBytes(cc)
	return parseRequestCacheControl(raw)
}
// cachedResponseAge estimates the current age, in seconds, of cached entry e at
// time now (Unix seconds). It returns the largest of three estimates:
//   - dateAge: now minus the entry's Date value (clamped via clampDateSeconds),
//   - resident: how long the entry has been held, derived from its ttl and exp,
//   - e.age: the Age value recorded for the entry.
//
// The resident estimate saturates at zero when exp lies ttl or more seconds in
// the future, for example because the clock moved backwards or exp was pushed
// out. Without that guard, ttl-(exp-now) would wrap around to an enormous
// unsigned age.
func cachedResponseAge(e *item, now uint64) uint64 {
	clampedDate := clampDateSeconds(e.date, now)

	// resident: seconds the entry has been stored. exp == 0 means no expiry was
	// recorded, so the ttl-based estimate is unavailable.
	resident := uint64(0)
	if e.exp != 0 {
		if e.exp <= now {
			// Past expiry: the whole ttl plus the time elapsed since expiry.
			resident = e.ttl + (now - e.exp)
		} else if remaining := e.exp - now; remaining < e.ttl {
			resident = e.ttl - remaining
		}
		// Otherwise remaining >= ttl: treat the entry as brand new instead of
		// underflowing.
	}

	dateAge := uint64(0)
	if clampedDate != 0 && now > clampedDate {
		dateAge = now - clampedDate
	}

	return max(dateAge, max(resident, e.age))
}
// appendWarningHeaders adds Warning header values to h. Code 110 ("Response is
// stale") is added when servedStale is true. Code 113 ("Heuristic expiration") is
// added when heuristicFreshness is true. The flags are independent, and when both
// are set, 110 is added before 113.
func appendWarningHeaders(h *fasthttp.ResponseHeader, servedStale, heuristicFreshness bool) { //nolint:revive // flags are intentional to represent Warning variants
	warnings := [...]struct {
		enabled bool
		value   string
	}{
		{enabled: servedStale, value: `110 - "Response is stale"`},
		{enabled: heuristicFreshness, value: `113 - "Heuristic expiration"`},
	}
	for _, w := range warnings {
		if w.enabled {
			h.Add(fiber.HeaderWarning, w.value)
		}
	}
}
// remainingFreshness returns the number of seconds left before e expires at time
// now. It returns 0 for a nil entry, an entry without an expiry (exp == 0), or an
// entry that has already expired.
func remainingFreshness(e *item, now uint64) uint64 {
	if e != nil && e.exp != 0 && e.exp > now {
		return e.exp - now
	}
	return 0
}
// isHeuristicFreshness reports whether entry e, whose current age is entryAge
// seconds, is older than 24 hours and relies on the middleware's configured
// Expiration for its freshness rather than on explicit response metadata.
// It returns false in any of these cases:
//   - entryAge is at most 24 hours,
//   - the stored response carried an Expires value,
//   - its Cache-Control has max-age or s-maxage,
//   - cfg.Expiration is not positive.
//
// NOTE(review): presumably used to decide the 113 "Heuristic expiration"
// Warning; confirm at the call site.
func isHeuristicFreshness(e *item, cfg *Config, entryAge uint64) bool {
	const heuristicAgeThresholdSeconds = uint64(24 * time.Hour / time.Second)
	if entryAge <= heuristicAgeThresholdSeconds {
		return false
	}
	if len(e.expires) > 0 {
		return false
	}
	// e.cacheControl is already a byte slice. Parse it directly instead of
	// converting to an unsafe string and straight back to bytes.
	if parsedCC := parseResponseCacheControl(e.cacheControl); parsedCC.maxAgeSet || parsedCC.sMaxAgeSet {
		return false
	}
	return cfg.Expiration > 0
}
// lookupCachedHeader returns the value of the first entry in headers whose key
// equals name case-insensitively. The returned slice is the stored value itself,
// not a copy. The boolean is false when no entry matches.
func lookupCachedHeader(headers []cachedHeader, name string) ([]byte, bool) {
	for idx := range headers {
		hdr := &headers[idx]
		if !utils.EqualFold(utils.UnsafeString(hdr.key), name) {
			continue
		}
		return hdr.value, true
	}
	return nil, false
}
// parseHTTPDate parses an HTTP-date value into Unix seconds using
// fasthttp.ParseHTTPDate. Empty or unparsable input reports false. Instants
// before the Unix epoch clamp to 0.
func parseHTTPDate(dateBytes []byte) (uint64, bool) {
	if len(dateBytes) > 0 {
		if parsed, err := fasthttp.ParseHTTPDate(dateBytes); err == nil {
			return safeUnixSeconds(parsed), true
		}
	}
	return 0, false
}
// clampDateSeconds returns dateSeconds when it is a usable timestamp and fallback
// otherwise. A timestamp is unusable when any of these hold:
//   - it is zero,
//   - it exceeds math.MaxInt64 and so cannot be a signed Unix time,
//   - it lies after fallback.
func clampDateSeconds(dateSeconds, fallback uint64) uint64 {
	const maxUnixSeconds = uint64(math.MaxInt64)
	usable := dateSeconds != 0 && dateSeconds <= maxUnixSeconds && dateSeconds <= fallback
	if !usable {
		return fallback
	}
	return dateSeconds
}
// safeUnixSeconds converts t to Unix seconds. Instants before the Unix epoch
// clamp to 0 so the result always fits an unsigned value.
func safeUnixSeconds(t time.Time) uint64 {
	if sec := t.Unix(); sec >= 0 {
		return uint64(sec)
	}
	return 0
}
// secondsToTime converts Unix seconds to a UTC time.Time. Values that do not fit
// in an int64 saturate at math.MaxInt64.
func secondsToTime(sec uint64) time.Time {
	clamped := int64(math.MaxInt64)
	if sec <= math.MaxInt64 {
		clamped = int64(sec)
	}
	return time.Unix(clamped, 0).UTC()
}
// secondsToDuration converts whole seconds to a time.Duration. It saturates at
// the largest representable duration instead of overflowing int64 nanoseconds.
func secondsToDuration(sec uint64) time.Duration {
	const maxSeconds = uint64(math.MaxInt64) / uint64(time.Second)
	if sec <= maxSeconds {
		return time.Duration(sec) * time.Second
	}
	return time.Duration(math.MaxInt64)
}
func parseVary(vary string) ([]string, bool) {
names := make([]string, 0, 8)
for part := range strings.SplitSeq(vary, ",") {
name := utils.TrimSpace(utilsstrings.ToLower(part))
if name == "" {
continue
}
if name == "*" {
return nil, true
}
names = append(names, name)
}
if len(names) == 0 {
return nil, false
}
sort.Strings(names)
return names, false
}
// makeBuildVaryKeyFunc returns a function that derives a cache-key suffix from the
// request headers named in a Vary manifest.
//
// For each name, in the order given, the function feeds SHA-256 with the name and
// then the request's value for that header, each followed by a NUL separator. It
// returns "|vary|" followed by the lower-case hex digest. hexBufPool supplies
// scratch buffers for the hex encoding; a missing, mistyped or undersized buffer
// is replaced, so arbitrary pool contents are tolerated.
//
// NOTE(review): hdr.Peek returns no bytes for an absent header, so "absent" and
// "present but empty" hash identically. RFC 9111 §4.1 treats them as different.
func makeBuildVaryKeyFunc(hexBufPool *sync.Pool) func([]string, *fasthttp.RequestHeader) string {
	return func(names []string, hdr *fasthttp.RequestHeader) string {
		// One NUL separator is shared by every Write in this call. hash.Hash
		// writers never modify or retain p, so reuse is safe. It also avoids
		// allocating a fresh []byte{0} twice per header name.
		sep := []byte{0}
		sum := sha256.New()
		for _, name := range names {
			_, _ = sum.Write(utils.UnsafeBytes(name)) //nolint:errcheck // hash.Hash.Write for std hashes never errors
			_, _ = sum.Write(sep)                     //nolint:errcheck // hash.Hash.Write for std hashes never errors
			_, _ = sum.Write(hdr.Peek(name))          //nolint:errcheck // hash.Hash.Write for std hashes never errors
			_, _ = sum.Write(sep)                     //nolint:errcheck // hash.Hash.Write for std hashes never errors
		}
		// Sum appends to hashBytes[:0]. Its capacity is exactly sha256.Size, so the
		// digest lands in the array itself and the returned slice can be dropped.
		var hashBytes [sha256.Size]byte
		sum.Sum(hashBytes[:0])

		v := hexBufPool.Get()
		bufPtr, ok := v.(*[]byte)
		if !ok || bufPtr == nil {
			b := make([]byte, hexLen)
			bufPtr = &b
		}
		buf := *bufPtr
		// Defensive in case someone changed Pool.New or Put a different sized buffer.
		if cap(buf) < hexLen {
			buf = make([]byte, hexLen)
		} else {
			buf = buf[:hexLen]
		}
		*bufPtr = buf
		hex.Encode(buf, hashBytes[:])
		// string(buf) copies, so the buffer can go back to the pool afterwards.
		result := "|vary|" + string(buf)
		hexBufPool.Put(bufPtr)
		return result
	}
}
// storeVaryManifest persists the Vary field names for a resource under
// manifestKey as a comma-joined list, with expiration exp, via manager.setRaw.
// An empty names slice stores nothing and returns nil.
func storeVaryManifest(ctx context.Context, manager *manager, manifestKey string, names []string, exp time.Duration) error {
	if len(names) > 0 {
		joined := strings.Join(names, ",")
		return manager.setRaw(ctx, manifestKey, utils.UnsafeBytes(joined), exp)
	}
	return nil
}
// loadVaryManifest reads the Vary manifest stored under manifestKey and parses it
// with parseVary. It returns the names and true when a non-empty manifest exists.
// A cache miss, an empty manifest, or a manifest containing "*" yields
// (nil, false, nil). Any other storage error is returned unchanged.
//
//nolint:gocritic // returning explicit values keeps the signature concise while avoiding unnecessary named results
func loadVaryManifest(ctx context.Context, manager *manager, manifestKey string) ([]string, bool, error) {
	raw, err := manager.getRaw(ctx, manifestKey)
	switch {
	case errors.Is(err, errCacheMiss):
		return nil, false, nil
	case err != nil:
		return nil, false, err
	}
	names, wildcard := parseVary(utils.UnsafeString(raw))
	if wildcard {
		return nil, false, nil
	}
	return names, len(names) > 0, nil
}
// allowsSharedCacheDirectives reports whether the parsed response directives in
// cc allow a shared cache to reuse a response to a request that carried an
// Authorization header.
//
// RFC 9111 §3.5 grants this only through response directives that allow shared
// storage: must-revalidate, public and s-maxage. "private" always forbids it.
//
// proxy-revalidate is deliberately not accepted. RFC 9111 §5.2.2.8 notes that it
// "on its own does not imply that a response is cacheable". It only adds a
// shared-cache revalidation requirement when combined with, for example, public.
//
// An Expires header (RFC 9111 §5.3) is likewise not an explicit shared-cache
// directive. It therefore never makes an authenticated response shareable by
// itself.
func allowsSharedCacheDirectives(cc responseCacheControl) bool {
	if cc.hasPrivate {
		return false
	}
	return cc.hasPublic || cc.sMaxAgeSet || cc.mustRevalidate
}
func allowsSharedCache(cc string) bool {
return allowsSharedCacheDirectives(parseResponseCacheControl(utils.UnsafeBytes(cc)))
}
// makeHashAuthFunc returns a function that hashes a header value (presumably the
// Authorization header) with SHA-256 and returns the lower-case hex digest.
// hexBufPool supplies scratch buffers for the hex encoding; a missing, mistyped
// or undersized buffer is replaced before use and returned to the pool afterwards.
func makeHashAuthFunc(hexBufPool *sync.Pool) func([]byte) string {
	return func(authHeader []byte) string {
		digest := sha256.Sum256(authHeader)

		bufPtr, ok := hexBufPool.Get().(*[]byte)
		if !ok || bufPtr == nil {
			fresh := make([]byte, hexLen)
			bufPtr = &fresh
		}
		if cap(*bufPtr) >= hexLen {
			*bufPtr = (*bufPtr)[:hexLen]
		} else {
			*bufPtr = make([]byte, hexLen)
		}

		hex.Encode(*bufPtr, digest[:])
		encoded := string(*bufPtr) // copy before the buffer is recycled
		hexBufPool.Put(bufPtr)
		return encoded
	}
}
================================================
FILE: middleware/cache/cache_test.go
================================================
// Special thanks to @codemicro for moving this to fiber core
// Original middleware: github.com/codemicro/fiber-cache
package cache
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
"github.com/gofiber/fiber/v3/middleware/etag"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// failingCacheStorage is an in-memory storage fake whose operations can be forced
// to fail. An error registered in errs under "get|<key>", "set|<key>" or
// "del|<key>" makes the matching operation on <key> return it. Values are copied
// on write and on read, and all access is guarded by mu.
type failingCacheStorage struct {
	data map[string][]byte
	errs map[string]error
	mu   sync.RWMutex
}

// mutatingStorage is an unsynchronized in-memory storage fake that passes each
// value through mutate before storing it, letting tests rewrite cached payloads.
// Stored values are returned without copying.
type mutatingStorage struct {
	data   map[string][]byte
	mutate func(key string, value []byte) []byte
}

// newFailingCacheStorage returns an empty failingCacheStorage with no injected errors.
func newFailingCacheStorage() *failingCacheStorage {
	s := &failingCacheStorage{}
	s.data = make(map[string][]byte)
	s.errs = make(map[string]error)
	return s
}

// newMutatingStorage returns an empty mutatingStorage. A nil mutate stores values as-is.
func newMutatingStorage(mutate func(key string, value []byte) []byte) *mutatingStorage {
	return &mutatingStorage{mutate: mutate, data: map[string][]byte{}}
}

// GetWithContext ignores the context and delegates to Get.
func (s *mutatingStorage) GetWithContext(_ context.Context, key string) ([]byte, error) {
	return s.Get(key)
}

// Get returns the stored value for key, or nil when absent. It never fails.
func (s *mutatingStorage) Get(key string) ([]byte, error) {
	return s.data[key], nil
}

// SetWithContext ignores the context and expiration and delegates to Set.
func (s *mutatingStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error {
	return s.Set(key, val, 0)
}

// Set stores mutate(key, val), or val when mutate is nil, under key. Empty keys
// and empty values are silently dropped, and the expiration is ignored.
func (s *mutatingStorage) Set(key string, val []byte, _ time.Duration) error {
	if key != "" && len(val) > 0 {
		if s.mutate != nil {
			val = s.mutate(key, val)
		}
		s.data[key] = val
	}
	return nil
}

// DeleteWithContext ignores the context and delegates to Delete.
func (s *mutatingStorage) DeleteWithContext(_ context.Context, key string) error {
	return s.Delete(key)
}

// Delete removes key; deleting a missing key is not an error.
func (s *mutatingStorage) Delete(key string) error {
	delete(s.data, key)
	return nil
}

// ResetWithContext ignores the context and delegates to Reset.
func (s *mutatingStorage) ResetWithContext(_ context.Context) error {
	return s.Reset()
}

// Reset replaces the backing map with a fresh, empty one.
func (s *mutatingStorage) Reset() error {
	s.data = map[string][]byte{}
	return nil
}

// Close drops the backing map; the storage must not be written afterwards.
func (s *mutatingStorage) Close() error {
	s.data = nil
	return nil
}

// GetWithContext returns a copy of the value for key, nil when absent, or the
// error injected under "get|<key>".
func (s *failingCacheStorage) GetWithContext(_ context.Context, key string) ([]byte, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if injected := s.errs["get|"+key]; injected != nil {
		return nil, injected
	}
	stored, found := s.data[key]
	if !found {
		return nil, nil
	}
	return append([]byte(nil), stored...), nil
}

// Get is GetWithContext with a background context.
func (s *failingCacheStorage) Get(key string) ([]byte, error) {
	return s.GetWithContext(context.Background(), key)
}

// SetWithContext stores a copy of val under key unless an error is injected under
// "set|<key>". The expiration is ignored.
func (s *failingCacheStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if injected := s.errs["set|"+key]; injected != nil {
		return injected
	}
	s.data[key] = append([]byte(nil), val...)
	return nil
}

// Set is SetWithContext with a background context.
func (s *failingCacheStorage) Set(key string, val []byte, exp time.Duration) error {
	return s.SetWithContext(context.Background(), key, val, exp)
}

// DeleteWithContext removes key unless an error is injected under "del|<key>".
func (s *failingCacheStorage) DeleteWithContext(_ context.Context, key string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if injected := s.errs["del|"+key]; injected != nil {
		return injected
	}
	delete(s.data, key)
	return nil
}

// Delete is DeleteWithContext with a background context.
func (s *failingCacheStorage) Delete(key string) error {
	return s.DeleteWithContext(context.Background(), key)
}

// ResetWithContext clears both the stored data and every injected error.
func (s *failingCacheStorage) ResetWithContext(context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.data = make(map[string][]byte)
	s.errs = make(map[string]error)
	return nil
}

// Reset is ResetWithContext with a background context.
func (s *failingCacheStorage) Reset() error {
	return s.ResetWithContext(context.Background())
}

// Close is a no-op.
func (*failingCacheStorage) Close() error { return nil }
// contextRecord captures what one storage call observed about its context.
type contextRecord struct {
	key      string // storage key passed to the call
	value    string // string stored under markerKey, or "" when absent
	canceled bool   // whether ctx.Err() matched context.Canceled at call time
}

// contextRecorderStorage wraps failingCacheStorage and logs the context seen by
// every GetWithContext, SetWithContext and DeleteWithContext call. It then
// delegates to the embedded storage, so injected errors still apply.
//
// The embedded storage guards its own state with a mutex. recordMu gives the
// recorded slices the same protection, so appends from the middleware and
// snapshots taken via recorded* cannot race.
type contextRecorderStorage struct {
	*failingCacheStorage
	deletes  []contextRecord
	gets     []contextRecord
	sets     []contextRecord
	recordMu sync.Mutex
}

// newContextRecorderStorage returns a recorder around a fresh failingCacheStorage.
func newContextRecorderStorage() *contextRecorderStorage {
	return &contextRecorderStorage{failingCacheStorage: newFailingCacheStorage()}
}

// contextRecordFrom builds the contextRecord for a call on key made with ctx.
func contextRecordFrom(ctx context.Context, key string) contextRecord {
	record := contextRecord{
		key:      key,
		canceled: errors.Is(ctx.Err(), context.Canceled),
	}
	if value, ok := ctx.Value(markerKey).(string); ok {
		record.value = value
	}
	return record
}

// record appends rec to the slice at dst while holding recordMu.
func (s *contextRecorderStorage) record(dst *[]contextRecord, rec contextRecord) {
	s.recordMu.Lock()
	defer s.recordMu.Unlock()
	*dst = append(*dst, rec)
}

// snapshot returns a copy of the slice at src, taken while holding recordMu.
func (s *contextRecorderStorage) snapshot(src *[]contextRecord) []contextRecord {
	s.recordMu.Lock()
	defer s.recordMu.Unlock()
	out := make([]contextRecord, len(*src))
	copy(out, *src)
	return out
}

// GetWithContext records ctx for key, then delegates to the embedded storage.
func (s *contextRecorderStorage) GetWithContext(ctx context.Context, key string) ([]byte, error) {
	s.record(&s.gets, contextRecordFrom(ctx, key))
	return s.failingCacheStorage.GetWithContext(ctx, key)
}

// SetWithContext records ctx for key, then delegates to the embedded storage.
func (s *contextRecorderStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
	s.record(&s.sets, contextRecordFrom(ctx, key))
	return s.failingCacheStorage.SetWithContext(ctx, key, val, exp)
}

// DeleteWithContext records ctx for key, then delegates to the embedded storage.
func (s *contextRecorderStorage) DeleteWithContext(ctx context.Context, key string) error {
	s.record(&s.deletes, contextRecordFrom(ctx, key))
	return s.failingCacheStorage.DeleteWithContext(ctx, key)
}

// recordedGets returns a snapshot of all recorded GetWithContext calls.
func (s *contextRecorderStorage) recordedGets() []contextRecord {
	return s.snapshot(&s.gets)
}

// recordedSets returns a snapshot of all recorded SetWithContext calls.
func (s *contextRecorderStorage) recordedSets() []contextRecord {
	return s.snapshot(&s.sets)
}

// recordedDeletes returns a snapshot of all recorded DeleteWithContext calls.
func (s *contextRecorderStorage) recordedDeletes() []contextRecord {
	return s.snapshot(&s.deletes)
}
// TestCacheStorageGetError checks that a storage error while reading the cached
// entry reaches the app's ErrorHandler, wrapped as "cache: failed to get key",
// instead of being treated as a cache miss.
func TestCacheStorageGetError(t *testing.T) {
	t.Parallel()
	storage := newFailingCacheStorage()
	// Fail every read of the entry key for GET /. The key format appears to be
	// "<path>_<method>"; confirm against the key generator.
	storage.errs["get|/_GET"] = errors.New("boom")
	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			// Keep the error for the assertions below and answer with a 500.
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})
	app.Use(New(Config{Storage: storage, Expiration: time.Second}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, "cache: failed to get key")
}
// TestCacheStorageSetError checks what happens when storing the response body
// fails on a cache miss. The error must reach the app's ErrorHandler, wrapped as
// "cache: failed to store raw key", instead of being dropped.
func TestCacheStorageSetError(t *testing.T) {
	t.Parallel()
	storage := newFailingCacheStorage()
	// The body of GET / is written under "/_GET_body"; make that write fail.
	storage.errs["set|/_GET_body"] = errors.New("boom")
	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			// Keep the error for the assertions below and answer with a 500.
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})
	app.Use(New(Config{Storage: storage, Expiration: time.Second}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, "cache: failed to store raw key")
}
// TestCacheStorageDeleteError checks that a failure while deleting an expired
// entry reaches the app's ErrorHandler, wrapped as
// "cache: failed to delete expired key".
func TestCacheStorageDeleteError(t *testing.T) {
	t.Parallel()
	storage := newFailingCacheStorage()
	storage.errs["del|/_GET"] = errors.New("boom")
	// Pre-seed an already expired entry for GET / so the middleware tries to delete it.
	// Use an obviously expired timestamp without relying on time-based conversions
	expired := &item{exp: 1}
	raw, err := expired.MarshalMsg(nil)
	require.NoError(t, err)
	storage.data["/_GET"] = raw
	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			// Keep the error for the assertions below and answer with a 500.
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})
	app.Use(New(Config{Storage: storage, Expiration: time.Second}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, "cache: failed to delete expired key")
}
// contextKey is a package-private context key type, so test markers cannot
// collide with context keys set by other packages.
type contextKey string

// markerKey is the context key under which tests store a label that the
// recording storage reads back.
const markerKey = contextKey("marker")

// contextWithMarker returns a background-derived context carrying label under markerKey.
func contextWithMarker(label string) context.Context {
	parent := context.Background()
	return context.WithValue(parent, markerKey, label)
}

// canceledContextWithMarker returns a context that carries label under markerKey
// and is already canceled by the time the caller receives it.
func canceledContextWithMarker(label string) context.Context {
	ctx, cancel := context.WithCancel(contextWithMarker(label))
	defer cancel()
	return ctx
}
// TestCacheEvictionPropagatesRequestContextToDelete checks a MaxBytes eviction.
// The storage deletes for the evicted entry must use the context of the request
// that triggered the eviction, even when that context is already canceled.
func TestCacheEvictionPropagatesRequestContextToDelete(t *testing.T) {
	t.Parallel()
	storage := newContextRecorderStorage()
	app := fiber.New()
	// Tag each request's context: /first gets a live "first" marker, /second a
	// canceled "evict" marker.
	app.Use(func(c fiber.Ctx) error {
		path := c.Path()
		if path == "/first" {
			c.SetContext(contextWithMarker("first"))
		}
		if path == "/second" {
			c.SetContext(canceledContextWithMarker("evict"))
		}
		return c.Next()
	})
	// With MaxBytes 5, caching "bbbb" after "aaa" presumably exceeds the budget
	// (3 + 4 bytes) and evicts the /first entry.
	app.Use(New(Config{Storage: storage, Expiration: time.Minute, MaxBytes: 5}))
	app.Get("/first", func(c fiber.Ctx) error {
		return c.SendString("aaa")
	})
	app.Get("/second", func(c fiber.Ctx) error {
		return c.SendString("bbbb")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/first", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/second", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	// Exactly the /first entry and its body are deleted, both with the second
	// request's canceled "evict" context.
	records := storage.recordedDeletes()
	require.Len(t, records, 2)
	var keys []string
	for _, rec := range records {
		keys = append(keys, rec.key)
		require.Equal(t, "evict", rec.value)
		require.True(t, rec.canceled)
	}
	require.ElementsMatch(t, []string{"/first_GET", "/first_GET_body"}, keys)
}
// TestCacheCleanupPropagatesRequestContextToDelete checks what happens when
// writing the entry key fails after the body was stored. The middleware must
// delete the orphaned body key using the request's context, even when that
// context is canceled, and still report the store failure.
func TestCacheCleanupPropagatesRequestContextToDelete(t *testing.T) {
	t.Parallel()
	storage := newContextRecorderStorage()
	// Only the entry write fails; the body write ("/_GET_body") succeeds.
	storage.errs["set|/_GET"] = errors.New("boom")
	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			// Keep the error for the assertions below and answer with a 500.
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})
	// Every request runs with an already canceled context labeled "cleanup".
	app.Use(func(c fiber.Ctx) error {
		c.SetContext(canceledContextWithMarker("cleanup"))
		return c.Next()
	})
	app.Use(New(Config{Storage: storage, Expiration: time.Minute}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("payload")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, "cache: failed to store key")
	// The single cleanup delete targets the body key and carries the request context.
	records := storage.recordedDeletes()
	require.Len(t, records, 1)
	require.Equal(t, "/_GET_body", records[0].key)
	require.Equal(t, "cleanup", records[0].value)
	require.True(t, records[0].canceled)
}
// TestCacheStorageOperationsObserveRequestContext checks that storage reads and
// writes issued by the middleware receive the request's context. A miss should
// write with the first request's context; a hit should read with the second
// request's context. Canceled contexts must still be passed through unchanged.
func TestCacheStorageOperationsObserveRequestContext(t *testing.T) {
	t.Parallel()
	storage := newContextRecorderStorage()
	app := fiber.New()
	// X-Context sets the marker label; X-Cancel: true makes the context pre-canceled.
	app.Use(func(c fiber.Ctx) error {
		ctxLabel := string(c.Request().Header.Peek("X-Context"))
		if ctxLabel == "" {
			return c.Next()
		}
		canceled := string(c.Request().Header.Peek("X-Cancel")) == "true"
		if canceled {
			c.SetContext(canceledContextWithMarker(ctxLabel))
		} else {
			c.SetContext(contextWithMarker(ctxLabel))
		}
		return c.Next()
	})
	app.Use(New(Config{Storage: storage, Expiration: time.Minute}))
	app.Get("/cache", func(c fiber.Ctx) error {
		return c.SendString("payload")
	})
	// First request: cache miss, so the entry and body are written with "store".
	firstReq := httptest.NewRequest(fiber.MethodGet, "/cache", http.NoBody)
	firstReq.Header.Set("X-Context", "store")
	firstReq.Header.Set("X-Cancel", "true")
	resp, err := app.Test(firstReq)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	// Second request: cache hit, so the entry and body are read with "fetch".
	secondReq := httptest.NewRequest(fiber.MethodGet, "/cache", http.NoBody)
	secondReq.Header.Set("X-Context", "fetch")
	secondReq.Header.Set("X-Cancel", "true")
	resp, err = app.Test(secondReq)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	// Only the first request writes: one entry and one body, both with "store".
	setRecords := storage.recordedSets()
	require.Len(t, setRecords, 2)
	for _, rec := range setRecords {
		require.Contains(t, []string{"/cache_GET", "/cache_GET_body"}, rec.key)
		require.Equal(t, "store", rec.value)
		require.True(t, rec.canceled)
	}
	// Reads from the first request's miss lookup carry "store" and are skipped.
	getRecords := storage.recordedGets()
	require.NotEmpty(t, getRecords)
	var fetchEntry, fetchBody bool
	for _, rec := range getRecords {
		if rec.value != "fetch" {
			continue
		}
		if rec.key == "/cache_GET" {
			require.True(t, rec.canceled)
			fetchEntry = true
		}
		if rec.key == "/cache_GET_body" {
			require.True(t, rec.canceled)
			fetchBody = true
		}
	}
	require.True(t, fetchEntry, "expected cached entry retrieval to observe request context")
	require.True(t, fetchBody, "expected cached body retrieval to observe request context")
}
// Test_Cache_CacheControl checks that a response served from the cache carries a
// Cache-Control header derived from Expiration: "public, max-age=10" for 10s.
func Test_Cache_CacheControl(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})
	// Prime the cache.
	_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	// Served from cache.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl))
}
// Test_Cache_CacheControl_Disabled checks that with DisableCacheControl set, a
// cached response carries no Cache-Control header.
func Test_Cache_CacheControl_Disabled(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Expiration:          10 * time.Second,
		DisableCacheControl: true,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})
	// Prime the cache.
	_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	// Served from cache.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Empty(t, resp.Header.Get(fiber.HeaderCacheControl))
}
// Test_Cache_Expired checks that an entry is regenerated once Expiration has
// passed, and that the regenerated response is itself cached again.
func Test_Cache_Expired(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{Expiration: 2 * time.Second}))
	// count increments on every handler run, so each fresh response has a new body.
	count := 0
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	// Sleep until the cache is expired
	// NOTE(review): this real-time sleep makes the test take over 3s.
	time.Sleep(3 * time.Second)
	// The expired entry must not be served: the handler runs again.
	respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	bodyCached, err := io.ReadAll(respCached.Body)
	require.NoError(t, err)
	if bytes.Equal(body, bodyCached) {
		t.Errorf("Cache should have expired: %s, %s", body, bodyCached)
	}
	// Next response should be also cached
	respCachedNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body)
	require.NoError(t, err)
	if !bytes.Equal(bodyCachedNextRound, bodyCached) {
		t.Errorf("Cache should not have expired: %s, %s", bodyCached, bodyCachedNextRound)
	}
}
// Test_Cache checks the default configuration: a second identical GET is served
// from the cache, so its body matches the first response even though the handler
// would produce a different body on each run.
func Test_Cache(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	// count increments on every handler run, so a repeated body proves a cache hit.
	count := 0
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(strconv.Itoa(count))
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	cachedReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	cachedResp, err := app.Test(cachedReq)
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	cachedBody, err := io.ReadAll(cachedResp.Body)
	require.NoError(t, err)
	require.Equal(t, cachedBody, body)
}
// go test -run Test_Cache_WithNoCacheRequestDirective
//
// Test_Cache_WithNoCacheRequestDirective checks request Cache-Control: no-cache.
//   - no-cache bypasses the stored entry and replaces it with the fresh response.
//   - Directive names match case-insensitively.
//   - Unrelated directives such as "my-no-cache" are ignored.
//
// The cache key evidently ignores the query string, so every request below
// shares the single entry for "/".
func Test_Cache_WithNoCacheRequestDirective(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query(c, "id", "1"))
	})
	// Request id = 1
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	require.Equal(t, []byte("1"), body)
	// Miss: the handler ran and its response is stored; entry = 1.
	// Request id = 2 without Cache-Control: no-cache
	cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	cachedResp, err := app.Test(cachedReq)
	require.NoError(t, err)
	cachedBody, err := io.ReadAll(cachedResp.Body)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedResp.Header.Get("X-Cache"))
	require.Equal(t, []byte("1"), cachedBody)
	// Hit: the stored response is served; entry stays 1.
	// Request id = 2 with Cache-Control: no-cache
	noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
	noCacheResp, err := app.Test(noCacheReq)
	require.NoError(t, err)
	noCacheBody, err := io.ReadAll(noCacheResp.Body)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
	require.Equal(t, []byte("2"), noCacheBody)
	// no-cache bypassed the entry; the fresh response replaced it; entry = 2.
	/* Check Test_Cache_WithETagAndNoCacheRequestDirective */
	// Request id = 2 with Cache-Control: no-cache again
	noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
	noCacheResp1, err := app.Test(noCacheReq1)
	require.NoError(t, err)
	noCacheBody1, err := io.ReadAll(noCacheResp1.Body)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
	require.Equal(t, []byte("2"), noCacheBody1)
	// Bypassed again and re-stored; entry = 2.
	// Request id = 3 with Cache-Control: NO-CACHE
	noCacheReqUpper := httptest.NewRequest(fiber.MethodGet, "/?id=3", http.NoBody)
	noCacheReqUpper.Header.Set(fiber.HeaderCacheControl, "NO-CACHE")
	noCacheRespUpper, err := app.Test(noCacheReqUpper)
	require.NoError(t, err)
	noCacheBodyUpper, err := io.ReadAll(noCacheRespUpper.Body)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, noCacheRespUpper.Header.Get("X-Cache"))
	require.Equal(t, []byte("3"), noCacheBodyUpper)
	// Upper-case directive is honored too; entry = 3.
	// Request id = 4 with Cache-Control: my-no-cache
	invalidReq := httptest.NewRequest(fiber.MethodGet, "/?id=4", http.NoBody)
	invalidReq.Header.Set(fiber.HeaderCacheControl, "my-no-cache")
	invalidResp, err := app.Test(invalidReq)
	require.NoError(t, err)
	invalidBody, err := io.ReadAll(invalidResp.Body)
	require.NoError(t, err)
	require.Equal(t, cacheHit, invalidResp.Header.Get("X-Cache"))
	require.Equal(t, []byte("3"), invalidBody)
	// Unknown directive is ignored: served from cache; entry stays 3.
	// Request id = 4 again without Cache-Control: no-cache
	cachedInvalidReq := httptest.NewRequest(fiber.MethodGet, "/?id=4", http.NoBody)
	cachedInvalidResp, err := app.Test(cachedInvalidReq)
	require.NoError(t, err)
	cachedInvalidBody, err := io.ReadAll(cachedInvalidResp.Body)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedInvalidResp.Header.Get("X-Cache"))
	require.Equal(t, []byte("3"), cachedInvalidBody)
	// Hit: entry stays 3.
	// Request id = 1 without Cache-Control: no-cache
	cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	cachedResp1, err := app.Test(cachedReq1)
	require.NoError(t, err)
	cachedBody1, err := io.ReadAll(cachedResp1.Body)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
	require.Equal(t, []byte("3"), cachedBody1)
	// Hit: the latest stored response (3) is served, not the original 1.
}
// go test -run Test_Cache_WithETagAndNoCacheRequestDirective
//
// Test_Cache_WithETagAndNoCacheRequestDirective combines the etag middleware
// (outer) with the cache (inner). It checks that:
//   - conditional requests on a hit get 304 from the stored entry,
//   - no-cache (any case) forces a miss that refreshes the entry,
//   - a 304 returned to one client does not stop the next plain request from
//     getting a 200 hit.
//
// The cache key evidently ignores the query string.
func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(
		etag.New(),
		New(),
	)
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query(c, "id", "1"))
	})
	// Request id = 1
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	// Miss: response stored; entry id = 1.
	// Remember the ETag of body "1" for conditional requests.
	etagToken := resp.Header.Get("Etag")
	// Request id = 2 with ETag but without Cache-Control: no-cache
	cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
	cachedResp, err := app.Test(cachedReq)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedResp.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusNotModified, cachedResp.StatusCode)
	// Hit: the stored entry (id = 1) matches the ETag, so 304 Not Modified.
	// Request id = 2 with ETag and Cache-Control: no-cache
	noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
	noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
	noCacheResp, err := app.Test(noCacheReq)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusOK, noCacheResp.StatusCode)
	// Miss: body "2" no longer matches the old ETag, so 200; entry id = 2.
	// Remember the ETag of body "2".
	etagToken = noCacheResp.Header.Get("Etag")
	// Request id = 3 with ETag and Cache-Control: NO-CACHE
	noCacheReqUpper := httptest.NewRequest(fiber.MethodGet, "/?id=3", http.NoBody)
	noCacheReqUpper.Header.Set(fiber.HeaderCacheControl, "NO-CACHE")
	noCacheReqUpper.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
	noCacheRespUpper, err := app.Test(noCacheReqUpper)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, noCacheRespUpper.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusOK, noCacheRespUpper.StatusCode)
	// Miss (case-insensitive directive): body "3" differs from the ETag of "2",
	// so 200; entry id = 3.
	// Request id = 2 with ETag and Cache-Control: no-cache again
	noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
	noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
	noCacheResp1, err := app.Test(noCacheReq1)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusNotModified, noCacheResp1.StatusCode)
	// Miss: body "2" matches the remembered ETag, so this client gets 304; the
	// entry is refreshed to id = 2.
	// Request id = 1 without ETag and Cache-Control: no-cache
	cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	cachedResp1, err := app.Test(cachedReq1)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusOK, cachedResp1.StatusCode)
	// Hit: a plain request gets the stored 200 response (presumably id = 2).
}
// go test -run Test_Cache_WithNoStoreRequestDirective
//
// Test_Cache_WithNoStoreRequestDirective checks request Cache-Control: no-store.
// Matching responses are served fresh but not stored, directive names match
// case-insensitively, and unrelated directives such as "my-no-store" do not
// prevent storage.
func Test_Cache_WithNoStoreRequestDirective(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query(c, "id", "1"))
	})
	// Request id = 2
	noStoreReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore)
	noStoreResp, err := app.Test(noStoreReq)
	require.NoError(t, err)
	noStoreBody, err := io.ReadAll(noStoreResp.Body)
	require.NoError(t, err)
	require.Equal(t, []byte("2"), noStoreBody)
	// Fresh response returned, nothing stored.
	// Request id = 3 with Cache-Control: NO-STORE
	noStoreReqUpper := httptest.NewRequest(fiber.MethodGet, "/?id=3", http.NoBody)
	noStoreReqUpper.Header.Set(fiber.HeaderCacheControl, "NO-STORE")
	noStoreRespUpper, err := app.Test(noStoreReqUpper)
	require.NoError(t, err)
	noStoreBodyUpper, err := io.ReadAll(noStoreRespUpper.Body)
	require.NoError(t, err)
	require.Equal(t, []byte("3"), noStoreBodyUpper)
	// Upper-case directive is honored: fresh response, nothing stored.
	// Request id = 4 with Cache-Control: my-no-store
	invalidReq := httptest.NewRequest(fiber.MethodGet, "/?id=4", http.NoBody)
	invalidReq.Header.Set(fiber.HeaderCacheControl, "my-no-store")
	invalidResp, err := app.Test(invalidReq)
	require.NoError(t, err)
	invalidBody, err := io.ReadAll(invalidResp.Body)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, invalidResp.Header.Get("X-Cache"))
	require.Equal(t, []byte("4"), invalidBody)
	// Unknown directive is ignored: miss, and the response is stored; entry = 4.
	// Request id = 4 again without Cache-Control
	cachedInvalidReq := httptest.NewRequest(fiber.MethodGet, "/?id=4", http.NoBody)
	cachedInvalidResp, err := app.Test(cachedInvalidReq)
	require.NoError(t, err)
	cachedInvalidBody, err := io.ReadAll(cachedInvalidResp.Body)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedInvalidResp.Header.Get("X-Cache"))
	require.Equal(t, []byte("4"), cachedInvalidBody)
	// Hit: the entry stored by the previous request is served.
}
// Test_Cache_WithSeveralRequests requests ten distinct paths ten times each
// and checks that every response echoes its own path parameter. This guards
// against a cached entry ever being served for the wrong key.
func Test_Cache_WithSeveralRequests(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Expiration: 10 * time.Second,
	}))
	app.Get("/:id", func(c fiber.Ctx) error {
		return c.SendString(c.Params("id"))
	})

	// check issues one request for id and asserts the echoed value. The
	// body is closed via defer so it is released even if an assertion fails.
	check := func(id int) {
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", id), http.NoBody))
		require.NoError(t, err)
		defer func() {
			closeErr := rsp.Body.Close()
			require.NoError(t, closeErr)
		}()

		raw, err := io.ReadAll(rsp.Body)
		require.NoError(t, err)
		echoed, err := strconv.Atoi(string(raw))
		require.NoError(t, err)
		require.Equal(t, id, echoed)
	}

	for range 10 {
		for id := range 10 {
			check(id)
		}
	}
}
// Test_Cache_Invalid_Expiration configures a zero Expiration and checks that
// caching still happens: the second response must carry the first body, not
// a freshly incremented counter.
// NOTE(review): presumably a zero Expiration falls back to the config
// default; confirm in the middleware's config handling.
func Test_Cache_Invalid_Expiration(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Expiration: 0 * time.Second}))

	hits := 0
	app.Get("/", func(c fiber.Ctx) error {
		hits++
		return c.SendString(strconv.Itoa(hits))
	})

	first, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	second, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)

	firstBody, err := io.ReadAll(first.Body)
	require.NoError(t, err)
	secondBody, err := io.ReadAll(second.Body)
	require.NoError(t, err)
	require.Equal(t, secondBody, firstBody)
}
// Test_Cache_Get checks the default method filter. POST responses always
// reflect the current query string. A second GET to /get with a different
// query returns the first GET's body, showing it was served from the cache.
func Test_Cache_Get(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	echo := func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	}
	app.Post("/", echo)
	app.Get("/get", echo)

	steps := []struct {
		method string
		target string
		want   string
	}{
		{fiber.MethodPost, "/?cache=123", "123"},
		{fiber.MethodPost, "/?cache=12345", "12345"}, // POST is not cached
		{fiber.MethodGet, "/get?cache=123", "123"},
		{fiber.MethodGet, "/get?cache=12345", "123"}, // served from cache
	}
	for _, step := range steps {
		resp, err := app.Test(httptest.NewRequest(step.method, step.target, http.NoBody))
		require.NoError(t, err)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, step.want, string(body))
	}
}
// Test_Cache_Post restricts caching to POST and checks the inverse of
// Test_Cache_Get. A repeated POST returns the first POST's body, while GET
// responses always reflect the current query string.
func Test_Cache_Post(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Methods: []string{fiber.MethodPost},
	}))

	echo := func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	}
	app.Post("/", echo)
	app.Get("/get", echo)

	steps := []struct {
		method string
		target string
		want   string
	}{
		{fiber.MethodPost, "/?cache=123", "123"},
		{fiber.MethodPost, "/?cache=12345", "123"}, // served from cache
		{fiber.MethodGet, "/get?cache=123", "123"},
		{fiber.MethodGet, "/get?cache=12345", "12345"}, // GET is not cached
	}
	for _, step := range steps {
		resp, err := app.Test(httptest.NewRequest(step.method, step.target, http.NoBody))
		require.NoError(t, err)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, step.want, string(body))
	}
}
// Test_Cache_NothingToCache configures a negative Expiration and checks that a
// later request is not answered with the earlier body, i.e. nothing is served
// from the cache.
func Test_Cache_NothingToCache(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Expiration: -(time.Second * 1)}))

	calls := 0
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		return c.SendString(strconv.Itoa(calls))
	})

	// fetchBody performs a GET / and returns the full response body.
	fetchBody := func() []byte {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		data, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		return data
	}

	first := fetchBody()
	time.Sleep(500 * time.Millisecond)
	second := fetchBody()

	if bytes.Equal(first, second) {
		t.Errorf("Cache should have expired: %s, %s", first, second)
	}
}
// Test_Cache_CustomNext uses a Next predicate that reports true for any
// non-200 status. A 200 response is served from cache with a Cache-Control
// header. A 500 response comes back without Cache-Control.
func Test_Cache_CustomNext(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Response().StatusCode() != fiber.StatusOK
		},
	}))

	okCount := 0
	app.Get("/", func(c fiber.Ctx) error {
		okCount++
		return c.SendString(strconv.Itoa(okCount))
	})

	errorCount := 0
	app.Get("/error", func(c fiber.Ctx) error {
		errorCount++
		return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(errorCount))
	})

	get := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}
	readAll := func(resp *http.Response) []byte {
		data, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		return data
	}

	// 200: the second body equals the first, so it came from the cache.
	first := readAll(get("/"))
	cached := get("/")
	require.True(t, bytes.Equal(first, readAll(cached)))
	require.NotEmpty(t, cached.Header.Get(fiber.HeaderCacheControl))

	// 500: Next reports true, and no Cache-Control is expected.
	_ = get("/error")
	require.Empty(t, get("/error").Header.Get(fiber.HeaderCacheControl))
}
// Test_CustomKey checks that a user-supplied KeyGenerator is invoked when a
// request passes through the cache middleware.
func Test_CustomKey(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	keyGenCalled := false
	app.Use(New(Config{
		KeyGenerator: func(c fiber.Ctx) string {
			keyGenCalled = true
			// Return a copy rather than the value c.Path() hands back.
			return utils.CopyString(c.Path())
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hi")
	})

	_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.True(t, keyGenCalled)
}
// Test_CustomExpiration drives the entry lifetime from the handler's
// Cache-Time response header via ExpirationGenerator. It checks three things:
// the generator runs, the 1s entry eventually expires, and the regenerated
// response is then cached again.
func Test_CustomExpiration(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	var (
		generatorCalled bool
		newCacheTime    int
	)
	// Lifetime in seconds comes from the Cache-Time response header,
	// defaulting to "600" when the header is absent.
	app.Use(New(Config{ExpirationGenerator: func(c fiber.Ctx, _ *Config) time.Duration {
		generatorCalled = true
		secs, err := strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
		newCacheTime = secs
		require.NoError(t, err)
		return time.Duration(secs) * time.Second
	}}))

	counter := 0
	app.Get("/", func(c fiber.Ctx) error {
		counter++
		c.Response().Header.Add("Cache-Time", "1")
		return c.SendString(strconv.Itoa(counter))
	})

	firstResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.True(t, generatorCalled)
	require.Equal(t, 1, newCacheTime)

	// Poll until the entry is no longer a hit. The timestamp tick can delay
	// expiry detection slightly, so allow up to three seconds.
	deadline := time.Now().Add(3 * time.Second)
	var freshResp *http.Response
	for {
		freshResp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		if freshResp.Header.Get("X-Cache") != cacheHit {
			break
		}
		require.True(t, time.Now().Before(deadline), "response remained cached beyond expected expiration")
		time.Sleep(50 * time.Millisecond)
	}

	firstBody, err := io.ReadAll(firstResp.Body)
	require.NoError(t, err)
	freshBody, err := io.ReadAll(freshResp.Body)
	require.NoError(t, err)
	if bytes.Equal(firstBody, freshBody) {
		t.Errorf("Cache should have expired: %s, %s", firstBody, freshBody)
	}

	// The regenerated response should now be served from the cache.
	nextResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	nextBody, err := io.ReadAll(nextResp.Body)
	require.NoError(t, err)
	if !bytes.Equal(nextBody, freshBody) {
		t.Errorf("Cache should not have expired: %s, %s", nextBody, freshBody)
	}
}
// Test_AdditionalE2EResponseHeaders checks that, with StoreResponseHeaders
// enabled, a custom response header set by the handler is present on the
// response replayed from the cache.
func Test_AdditionalE2EResponseHeaders(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		StoreResponseHeaders: true,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Response().Header.Add("X-Foobar", "foobar")
		return c.SendString("hi")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	require.Equal(t, "foobar", resp.Header.Get("X-Foobar"))

	// The handler sets X-Foobar on every call, so the header alone proves
	// nothing. Only a confirmed cache hit shows it was restored from the
	// stored entry.
	req = httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	require.Equal(t, "foobar", resp.Header.Get("X-Foobar"))
}
// Test_CacheHeader checks the X-Cache status reported for four requests:
// - a first GET is a miss;
// - a repeated GET is a hit;
// - a POST (not cached by default) is unreachable;
// - a 500 response (excluded via Next) is unreachable.
func Test_CacheHeader(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Expiration: 10 * time.Second,
		Next: func(c fiber.Ctx) bool {
			return c.Response().StatusCode() != fiber.StatusOK
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	})

	failures := 0
	app.Get("/error", func(c fiber.Ctx) error {
		failures++
		c.Response().Header.Add("Cache-Time", "1")
		return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(failures))
	})

	// status sends one request and returns its X-Cache header value.
	status := func(method, target string) string {
		resp, err := app.Test(httptest.NewRequest(method, target, http.NoBody))
		require.NoError(t, err)
		return resp.Header.Get("X-Cache")
	}

	require.Equal(t, cacheMiss, status(fiber.MethodGet, "/"))
	require.Equal(t, cacheHit, status(fiber.MethodGet, "/"))
	require.Equal(t, cacheUnreachable, status(fiber.MethodPost, "/?cache=12345"))
	require.Equal(t, cacheUnreachable, status(fiber.MethodGet, "/error"))
}
// Test_Cache_WithHead checks that HEAD requests are cached. The first is a
// miss, the second a hit, and both bodies are identical.
func Test_Cache_WithHead(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	invocations := 0
	handler := func(c fiber.Ctx) error {
		invocations++
		c.Response().Header.Add("Cache-Time", "1")
		return c.SendString(strconv.Itoa(invocations))
	}
	app.RouteChain("/").Get(handler).Head(handler)

	head := func() *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodHead, "/", http.NoBody))
		require.NoError(t, err)
		return resp
	}

	first := head()
	require.Equal(t, cacheMiss, first.Header.Get("X-Cache"))
	second := head()
	require.Equal(t, cacheHit, second.Header.Get("X-Cache"))

	firstBody, err := io.ReadAll(first.Body)
	require.NoError(t, err)
	secondBody, err := io.ReadAll(second.Body)
	require.NoError(t, err)
	require.Equal(t, secondBody, firstBody)
}
// Test_Cache_WithHeadThenGet checks that HEAD and GET to the same URL are
// cached independently. A cached HEAD (empty body) is not reused for GET,
// which has its own miss then hit.
func Test_Cache_WithHeadThenGet(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	handler := func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	}
	app.RouteChain("/").Get(handler).Head(handler)

	steps := []struct {
		method    string
		wantBody  string
		wantCache string
	}{
		{fiber.MethodHead, "", cacheMiss},
		{fiber.MethodHead, "", cacheHit},
		{fiber.MethodGet, "123", cacheMiss},
		{fiber.MethodGet, "123", cacheHit},
	}
	for _, s := range steps {
		resp, err := app.Test(httptest.NewRequest(s.method, "/?cache=123", http.NoBody))
		require.NoError(t, err)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, s.wantBody, string(body))
		require.Equal(t, s.wantCache, resp.Header.Get("X-Cache"))
	}
}
// Test_CustomCacheHeader checks that Config.CacheHeader renames the cache
// status header. Both miss and hit are reported under Cache-Status, and the
// default X-Cache header is not emitted.
func Test_CustomCacheHeader(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		CacheHeader: "Cache-Status",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("Cache-Status"))
	require.Empty(t, resp.Header.Get("X-Cache"))

	// The hit path must honor the custom header name as well.
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("Cache-Status"))
	require.Empty(t, resp.Header.Get("X-Cache"))
}
// Test_CacheInvalidation checks two things. A repeated request is served from
// cache. A request for which CacheInvalidator returns true (?invalidate=true)
// gets a freshly generated body instead.
func Test_CacheInvalidation(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		CacheInvalidator: func(c fiber.Ctx) bool {
			return fiber.Query[bool](c, "invalidate")
		},
	}))

	generated := 0
	app.Get("/", func(c fiber.Ctx) error {
		generated++
		return c.SendString(strconv.Itoa(generated))
	})

	// request performs a GET and returns the response plus its full body.
	request := func(target string) (*http.Response, []byte) {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, http.NoBody))
		require.NoError(t, err)
		data, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		return resp, data
	}

	_, original := request("/")
	cachedResp, cached := request("/")
	require.True(t, bytes.Equal(original, cached))
	require.NotEmpty(t, cachedResp.Header.Get(fiber.HeaderCacheControl))

	_, refreshed := request("/?invalidate=true")
	require.NotEqual(t, original, refreshed)
}
// Test_CacheInvalidation_noCacheEntry checks that CacheInvalidator is not
// called when the very first request for a key arrives, since no cached
// entry exists yet.
func Test_CacheInvalidation_noCacheEntry(t *testing.T) {
	t.Parallel()
	// Subtest name fixed: was "...if no cache entry exist " (grammar error plus
	// a trailing space that go test renders as "exist_").
	t.Run("Cache Invalidator should not be called if no cache entry exists", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		cacheInvalidatorExecuted := false
		app.Use(New(Config{
			CacheInvalidator: func(c fiber.Ctx) bool {
				cacheInvalidatorExecuted = true
				return fiber.Query[bool](c, "invalidate")
			},
			MaxBytes: 10 * 1024 * 1024,
		}))
		// First and only request: nothing has been stored for this key yet.
		_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", http.NoBody))
		require.NoError(t, err)
		require.False(t, cacheInvalidatorExecuted)
	})
}
// Test_CacheInvalidation_removeFromHeap repeats the invalidation scenario with
// MaxBytes set. A repeated request is a cache hit, and ?invalidate=true
// forces a freshly generated body.
// NOTE(review): presumably MaxBytes makes entries tracked in the eviction
// heap this test is named after; confirm in the middleware.
func Test_CacheInvalidation_removeFromHeap(t *testing.T) {
	t.Parallel()
	t.Run("Invalidate and remove from the heap", func(t *testing.T) {
		t.Parallel()

		app := fiber.New()
		app.Use(New(Config{
			CacheInvalidator: func(c fiber.Ctx) bool {
				return fiber.Query[bool](c, "invalidate")
			},
			MaxBytes: 10 * 1024 * 1024,
		}))

		generated := 0
		app.Get("/", func(c fiber.Ctx) error {
			generated++
			return c.SendString(strconv.Itoa(generated))
		})

		// fetch performs a GET and returns the response plus its full body.
		fetch := func(target string) (*http.Response, []byte) {
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, http.NoBody))
			require.NoError(t, err)
			data, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			return resp, data
		}

		_, firstBody := fetch("/")
		cachedResp, cachedBody := fetch("/")
		require.True(t, bytes.Equal(firstBody, cachedBody))
		require.NotEmpty(t, cachedResp.Header.Get(fiber.HeaderCacheControl))

		_, invalidatedBody := fetch("/?invalidate=true")
		require.NotEqual(t, firstBody, invalidatedBody)
	})
}
// Test_CacheStorage_CustomHeaders checks that, with an external Storage, a
// cached response is replayed with the Content-Type and Content-Encoding the
// handler originally set.
func Test_CacheStorage_CustomHeaders(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Storage:  memory.New(),
		MaxBytes: 10 * 1024 * 1024,
	}))
	handlerCalls := 0
	app.Get("/", func(c fiber.Ctx) error {
		handlerCalls++
		c.Response().Header.Set("Content-Type", "text/xml")
		c.Response().Header.Set("Content-Encoding", "utf8")
		return c.Send([]byte("Test"))
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	// The body is constant, so equal bodies alone prove nothing. Assert the
	// hit explicitly and check the headers came back from storage.
	respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, respCached.Header.Get("X-Cache"))
	require.Equal(t, 1, handlerCalls)
	bodyCached, err := io.ReadAll(respCached.Body)
	require.NoError(t, err)
	require.True(t, bytes.Equal(body, bodyCached))
	require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))
	require.Equal(t, "text/xml", respCached.Header.Get(fiber.HeaderContentType))
	require.Equal(t, "utf8", respCached.Header.Get(fiber.HeaderContentEncoding))
}
// stableAscendingExpiration returns an ExpirationGenerator whose n-th call
// yields n hours. Cache time points are refreshed only periodically, so
// entries created in quick succession could otherwise share an expiration and
// end up in arbitrary order. Strictly increasing expirations keep the order
// deterministic. The returned closure keeps unsynchronized state, so it is
// not safe for concurrent use.
func stableAscendingExpiration() func(c1 fiber.Ctx, c2 *Config) time.Duration {
	var calls time.Duration
	return func(fiber.Ctx, *Config) time.Duration {
		calls++
		return calls * time.Hour
	}
}
// Test_Cache_MaxBytesOrder checks eviction order under a 2-byte budget with
// 1-byte responses. Each insert beyond capacity evicts the entry with the
// earliest expiration, which stableAscendingExpiration makes the oldest one.
func Test_Cache_MaxBytesOrder(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		MaxBytes:            2,
		ExpirationGenerator: stableAscendingExpiration(),
	}))
	app.Get("/*", func(c fiber.Ctx) error {
		return c.SendString("1")
	})

	type step struct{ path, want string }
	steps := []step{
		// a and b fill the 2-byte cache
		{"/a", cacheMiss},
		{"/b", cacheMiss},
		{"/a", cacheHit},
		{"/b", cacheHit},
		// c evicts a
		{"/c", cacheMiss},
		{"/b", cacheHit},
		// a evicts b
		{"/a", cacheMiss},
		{"/c", cacheHit},
		// b evicts c
		{"/b", cacheMiss},
		{"/c", cacheMiss},
	}
	for i, s := range steps {
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, s.path, http.NoBody))
		require.NoError(t, err)
		require.Equal(t, s.want, rsp.Header.Get("X-Cache"), "Case %v", i)
	}
}
// Test_Cache_MaxBytesSizes checks eviction by size under a 7-byte budget. The
// last path segment sets the response length in bytes. A response larger than
// the whole budget is reported as unreachable.
func Test_Cache_MaxBytesSizes(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		MaxBytes:            7,
		ExpirationGenerator: stableAscendingExpiration(),
	}))
	app.Get("/*", func(c fiber.Ctx) error {
		n, err := strconv.Atoi(string(c.RequestCtx().URI().LastPathSegment()))
		require.NoError(t, err)
		return c.Send(make([]byte, n))
	})

	type step struct{ path, want string }
	steps := []step{
		{"/1", cacheMiss},
		{"/2", cacheMiss},
		{"/3", cacheMiss},
		{"/4", cacheMiss}, // 1+2+3+4 > 7, so 1 and 2 are evicted
		{"/3", cacheHit},
		{"/1", cacheMiss},
		{"/2", cacheMiss},
		{"/8", cacheUnreachable}, // larger than MaxBytes
	}
	for i, s := range steps {
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, s.path, http.NoBody))
		require.NoError(t, err)
		require.Equal(t, s.want, rsp.Header.Get("X-Cache"), "Case %v", i)
	}
}
// Test_Cache_UncacheableStatusCodes checks that responses with each listed
// status code pass through with that status and are reported as unreachable
// by the cache.
func Test_Cache_UncacheableStatusCodes(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	// The route replies with whatever status code appears in the path.
	app.Get("/:statusCode", func(c fiber.Ctx) error {
		code, err := strconv.Atoi(c.Params("statusCode"))
		require.NoError(t, err)
		return c.Status(code).SendString("foo")
	})

	codes := []int{
		// 1xx informational
		fiber.StatusContinue,
		fiber.StatusSwitchingProtocols,
		fiber.StatusProcessing,
		fiber.StatusEarlyHints,
		// 2xx success
		fiber.StatusCreated,
		fiber.StatusAccepted,
		fiber.StatusResetContent,
		fiber.StatusMultiStatus,
		fiber.StatusAlreadyReported,
		fiber.StatusIMUsed,
		// 3xx redirection
		fiber.StatusFound,
		fiber.StatusSeeOther,
		fiber.StatusNotModified,
		fiber.StatusUseProxy,
		fiber.StatusSwitchProxy,
		fiber.StatusTemporaryRedirect,
		// 4xx client errors
		fiber.StatusBadRequest,
		fiber.StatusUnauthorized,
		fiber.StatusPaymentRequired,
		fiber.StatusForbidden,
		fiber.StatusNotAcceptable,
		fiber.StatusProxyAuthRequired,
		fiber.StatusRequestTimeout,
		fiber.StatusConflict,
		fiber.StatusLengthRequired,
		fiber.StatusPreconditionFailed,
		fiber.StatusRequestEntityTooLarge,
		fiber.StatusUnsupportedMediaType,
		fiber.StatusRequestedRangeNotSatisfiable,
		fiber.StatusExpectationFailed,
		fiber.StatusMisdirectedRequest,
		fiber.StatusUnprocessableEntity,
		fiber.StatusLocked,
		fiber.StatusFailedDependency,
		fiber.StatusTooEarly,
		fiber.StatusUpgradeRequired,
		fiber.StatusPreconditionRequired,
		fiber.StatusTooManyRequests,
		fiber.StatusRequestHeaderFieldsTooLarge,
		fiber.StatusTeapot,
		fiber.StatusUnavailableForLegalReasons,
		// 5xx server errors
		fiber.StatusInternalServerError,
		fiber.StatusBadGateway,
		fiber.StatusServiceUnavailable,
		fiber.StatusGatewayTimeout,
		fiber.StatusHTTPVersionNotSupported,
		fiber.StatusVariantAlsoNegotiates,
		fiber.StatusInsufficientStorage,
		fiber.StatusLoopDetected,
		fiber.StatusNotExtended,
		fiber.StatusNetworkAuthenticationRequired,
	}

	for _, code := range codes {
		target := fmt.Sprintf("/%d", code)
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
		require.Equal(t, code, resp.StatusCode)
	}
}
// TestCacheAgeHeader checks that a freshly generated response carries Age: 0,
// and that a cache hit a few seconds later reports a positive Age.
func TestCacheAgeHeader(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	fresh, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, "0", fresh.Header.Get(fiber.HeaderAge))

	time.Sleep(4 * time.Second)

	aged, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, aged.Header.Get("X-Cache"))

	ageSeconds, err := strconv.Atoi(aged.Header.Get(fiber.HeaderAge))
	require.NoError(t, err)
	require.Positive(t, ageSeconds)
}
// TestCacheUpstreamAge uses a handler that reports Age: 5, older than the 3s
// Expiration. The test expects the upstream Age to be passed through
// unchanged, and a later request to be reported as unreachable rather than a
// cache hit.
func TestCacheUpstreamAge(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Expiration: 3 * time.Second}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderAge, "5")
		return c.SendString("hi")
	})

	get := func() *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		return resp
	}

	first := get()
	require.Equal(t, "5", first.Header.Get(fiber.HeaderAge))

	time.Sleep(1500 * time.Millisecond)

	second := get()
	require.Equal(t, cacheUnreachable, second.Header.Get("X-Cache"))
	require.Equal(t, "5", second.Header.Get(fiber.HeaderAge))
}
// Test_CacheRequestMaxAgeRevalidates checks that a request carrying
// Cache-Control: max-age=0 bypasses a fresh entry and reruns the handler. A
// following plain request is then a hit with the regenerated body.
func Test_CacheRequestMaxAgeRevalidates(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Expiration: 30 * time.Second,
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|req-max-age-zero"
		},
	}))

	var hits int
	app.Get("/", func(c fiber.Ctx) error {
		hits++
		c.Set(fiber.HeaderCacheControl, "public, max-age=30")
		return c.SendString(strconv.Itoa(hits))
	})

	// exchange sends a GET, adding a request Cache-Control when cacheControl
	// is non-empty, and asserts the X-Cache status and the body.
	exchange := func(cacheControl, wantStatus, wantBody string) {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		if cacheControl != "" {
			req.Header.Set(fiber.HeaderCacheControl, cacheControl)
		}
		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, wantStatus, resp.Header.Get("X-Cache"))
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(body))
	}

	exchange("", cacheMiss, "1")
	exchange("max-age=0", cacheMiss, "2")
	exchange("", cacheHit, "2")
}
// Test_CacheExpiresFutureAllowsCaching checks that a response whose Expires
// header lies 30s in the future is cached. The second request is a hit and
// returns the first body.
func Test_CacheExpiresFutureAllowsCaching(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		StoreResponseHeaders: true,
	}))
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		// http.TimeFormat is the IMF-fixdate form (zone always "GMT") that
		// HTTP requires for Expires. time.RFC1123 on a UTC time renders "UTC".
		c.Set(fiber.HeaderExpires, time.Now().Add(30*time.Second).UTC().Format(http.TimeFormat))
		return c.SendString("expires" + strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "expires1", string(body))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "expires1", string(body))
}
// Test_CacheExpiresPastPreventsCaching checks that a response whose Expires
// header lies in the past is never cached. Both requests are reported as
// unreachable, and the handler runs each time.
func Test_CacheExpiresPastPreventsCaching(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		// http.TimeFormat is the IMF-fixdate form (zone always "GMT") that
		// HTTP requires for Expires. time.RFC1123 on a UTC time renders "UTC".
		c.Set(fiber.HeaderExpires, time.Now().Add(-1*time.Minute).UTC().Format(http.TimeFormat))
		return c.SendString("expires" + strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "expires1", string(body))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "expires2", string(body))
}
// Test_CacheAllowsSharedCacheMustRevalidateWithAuthorization checks that a
// response to a request with an Authorization header is still cached when the
// response carries must-revalidate. The repeated request is a hit with the
// first body.
func Test_CacheAllowsSharedCacheMustRevalidateWithAuthorization(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Expiration: 30 * time.Second,
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|must-revalidate-auth"
		},
	}))

	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "must-revalidate, max-age=60")
		return c.SendString("auth" + strconv.Itoa(calls))
	})

	// authorizedGet sends a GET with an Authorization header and asserts the
	// X-Cache status and the body.
	authorizedGet := func(wantStatus, wantBody string) {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, wantStatus, resp.Header.Get("X-Cache"))
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(body))
	}

	authorizedGet(cacheMiss, "auth1")
	authorizedGet(cacheHit, "auth1")
}
// Test_CacheAllowsSharedCacheProxyRevalidateWithAuthorization checks that a
// response to a request with an Authorization header is still cached when the
// response carries proxy-revalidate. The repeated request is a hit with the
// first body.
func Test_CacheAllowsSharedCacheProxyRevalidateWithAuthorization(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Expiration: 30 * time.Second,
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|proxy-revalidate-auth"
		},
	}))

	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "proxy-revalidate, max-age=60")
		return c.SendString("proxy" + strconv.Itoa(calls))
	})

	// authorizedGet sends a GET with an Authorization header and asserts the
	// X-Cache status and the body.
	authorizedGet := func(wantStatus, wantBody string) {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, wantStatus, resp.Header.Get("X-Cache"))
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(body))
	}

	authorizedGet(cacheMiss, "proxy1")
	authorizedGet(cacheHit, "proxy1")
}
// Test_CacheInvalidExpiresStoredAsStale uses a response with an unparsable
// Expires header and checks three consecutive requests. Each is a miss that
// reruns the handler, yet both the entry key and its "_body" key remain
// present in storage after every request.
func Test_CacheInvalidExpiresStoredAsStale(t *testing.T) {
	t.Parallel()

	storage := newFailingCacheStorage()
	app := fiber.New()
	app.Use(New(Config{
		Expiration: 30 * time.Second,
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|invalid-expires"
		},
		Storage: storage,
	}))

	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "public")
		c.Set(fiber.HeaderExpires, "invalid-date")
		return c.SendString("body" + strconv.Itoa(calls))
	})

	expectedKey := "/|invalid-expires_GET"
	for round := 1; round <= 3; round++ {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, "body"+strconv.Itoa(round), string(body))

		require.Contains(t, storage.data, expectedKey)
		require.Contains(t, storage.data, expectedKey+"_body")
	}
}
// Test_CacheSMaxAgeOverridesMaxAgeWhenShorter checks that s-maxage=1 governs
// the entry lifetime over max-age=10. The entry is a hit immediately, but
// after about 1.7s the handler runs again.
func Test_CacheSMaxAgeOverridesMaxAgeWhenShorter(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	var generated int
	app.Get("/", func(c fiber.Ctx) error {
		generated++
		c.Set(fiber.HeaderCacheControl, "public, max-age=10, s-maxage=1")
		return c.SendString(strconv.Itoa(generated))
	})

	get := func() *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, cacheMiss, get().Header.Get("X-Cache"))

	hit := get()
	require.Equal(t, cacheHit, hit.Header.Get("X-Cache"))
	hitBody, err := io.ReadAll(hit.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(hitBody))

	time.Sleep(1700 * time.Millisecond)

	refreshed := get()
	require.Equal(t, cacheMiss, refreshed.Header.Get("X-Cache"))
	refreshedBody, err := io.ReadAll(refreshed.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(refreshedBody))
}
// Test_CacheSMaxAgeOverridesMaxAgeWhenLonger checks that s-maxage=2 governs
// the entry lifetime over the shorter max-age=1. The entry is still a hit
// about 1.2s after being stored (past max-age) and is regenerated after a
// further 1.7s (past s-maxage).
func Test_CacheSMaxAgeOverridesMaxAgeWhenLonger(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=1, s-maxage=2")
		return c.SendString(strconv.Itoa(count))
	})
	// Wait until we are within the first 100ms of a wall-clock second so the
	// sleeps below fall on predictable sides of the 1s and 2s boundaries.
	// NOTE(review): presumably needed because cache timestamps have
	// whole-second granularity; confirm against the middleware.
	for time.Now().Nanosecond() >= int(100*time.Millisecond) {
		time.Sleep(10 * time.Millisecond)
	}
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	// About 1.2s later: past max-age=1 but within s-maxage=2, so still a hit.
	time.Sleep(1200 * time.Millisecond)
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))
	// About 2.9s after storing: past s-maxage=2, so the handler runs again.
	time.Sleep(1700 * time.Millisecond)
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
}
// Test_CacheOnlyIfCachedMiss checks that an only-if-cached request with no
// stored entry gets 504 Gateway Timeout, is reported as unreachable, and
// never reaches the handler.
func Test_CacheOnlyIfCachedMiss(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	var handlerCalls int
	app.Get("/", func(c fiber.Ctx) error {
		handlerCalls++
		return c.SendString("ok")
	})

	onlyIfCached := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	onlyIfCached.Header.Set(fiber.HeaderCacheControl, "only-if-cached")

	resp, err := app.Test(onlyIfCached)
	require.NoError(t, err)

	require.Equal(t, fiber.StatusGatewayTimeout, resp.StatusCode)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	require.Equal(t, 0, handlerCalls)
}
// Test_CacheOnlyIfCachedStaleNotServed checks an only-if-cached request made
// after the 1s max-age has lapsed. The stale entry is not served: the reply is
// 504, the status is unreachable, and the handler is not invoked again.
func Test_CacheOnlyIfCachedStaleNotServed(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	var generated int
	app.Get("/", func(c fiber.Ctx) error {
		generated++
		c.Set(fiber.HeaderCacheControl, "public, max-age=1")
		return c.SendString(strconv.Itoa(generated))
	})

	initial, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, initial.Header.Get("X-Cache"))

	// Let the entry go stale.
	time.Sleep(1500 * time.Millisecond)

	onlyIfCached := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	onlyIfCached.Header.Set(fiber.HeaderCacheControl, "only-if-cached")
	resp, err := app.Test(onlyIfCached)
	require.NoError(t, err)

	require.Equal(t, fiber.StatusGatewayTimeout, resp.StatusCode)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	require.Equal(t, 1, generated)
}
// Test_CacheMaxStaleServesStaleResponse checks that after the 2s max-age has
// lapsed, a request with max-stale=5 is still served the stored body as a
// hit, without rerunning the handler.
func Test_CacheMaxStaleServesStaleResponse(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=2")
		return c.SendString(strconv.Itoa(count))
	})

	plainGet := func() *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, cacheMiss, plainGet().Header.Get("X-Cache"))
	require.Equal(t, cacheHit, plainGet().Header.Get("X-Cache"))

	// Wait past max-age so the entry is stale.
	time.Sleep(2500 * time.Millisecond)

	staleReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	staleReq.Header.Set(fiber.HeaderCacheControl, "max-stale=5")
	resp, err := app.Test(staleReq)
	require.NoError(t, err)
	require.Equalf(t, cacheHit, resp.Header.Get("X-Cache"), "dirs=%+v Age=%s count=%d", parseRequestCacheControlString("max-stale=5"), resp.Header.Get(fiber.HeaderAge), count)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))
	require.Equal(t, 1, count)
}
// Test_CacheMaxStaleRespectsMustRevalidate verifies that "must-revalidate"
// on the stored response takes precedence over a client's "max-stale=30":
// once the 1s max-age has elapsed, the handler runs again and X-Cache
// reports a miss.
func Test_CacheMaxStaleRespectsMustRevalidate(t *testing.T) {
	t.Parallel()

	var calls int
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "public, max-age=1, must-revalidate")
		return c.SendString(strconv.Itoa(calls))
	})

	primed, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, primed.Header.Get("X-Cache"))

	time.Sleep(1500 * time.Millisecond)

	tolerant := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	tolerant.Header.Set(fiber.HeaderCacheControl, "max-stale=30")
	refetched, err := app.Test(tolerant)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, refetched.Header.Get("X-Cache"))

	payload, err := io.ReadAll(refetched.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(payload))
	require.Equal(t, 2, calls)
}
// Test_CacheMaxStaleRespectsProxyRevalidateSharedAuth verifies that, for an
// authorized request whose response was cached via
// "s-maxage=1, proxy-revalidate", a later "max-stale=30" request is not given
// the stale entry: the handler runs again and X-Cache reports a miss.
func Test_CacheMaxStaleRespectsProxyRevalidateSharedAuth(t *testing.T) {
	t.Parallel()

	var calls int
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "s-maxage=1, proxy-revalidate")
		return c.SendString(strconv.Itoa(calls))
	})

	// authorized builds a GET carrying the bearer token and, when non-empty,
	// the given request Cache-Control value.
	authorized := func(cacheControl string) *http.Request {
		r := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		r.Header.Set(fiber.HeaderAuthorization, "Bearer abc")
		if cacheControl != "" {
			r.Header.Set(fiber.HeaderCacheControl, cacheControl)
		}
		return r
	}

	res, err := app.Test(authorized(""))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, res.Header.Get("X-Cache"))

	time.Sleep(1500 * time.Millisecond)

	res, err = app.Test(authorized("max-stale=30"))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, res.Header.Get("X-Cache"))

	payload, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(payload))
	require.Equal(t, 2, calls)
}
// Test_CachePreservesCacheControlHeaders verifies that the Cache-Control,
// Expires and ETag values set by the handler are identical on the initial
// miss and on the subsequent cache hit.
func Test_CachePreservesCacheControlHeaders(t *testing.T) {
	t.Parallel()

	const (
		cacheControl = "public, max-age=5, immutable"
		etag         = `W/"abc"`
	)

	app := fiber.New()
	app.Use(New())
	expires := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat)
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, cacheControl)
		c.Set(fiber.HeaderExpires, expires)
		c.Set(fiber.HeaderETag, etag)
		return c.SendString("ok")
	})

	// The miss and the hit must expose exactly the same headers.
	for _, wantCache := range []string{cacheMiss, cacheHit} {
		res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, wantCache, res.Header.Get("X-Cache"))
		require.Equal(t, cacheControl, res.Header.Get(fiber.HeaderCacheControl))
		require.Equal(t, expires, res.Header.Get(fiber.HeaderExpires))
		require.Equal(t, etag, res.Header.Get(fiber.HeaderETag))
	}
}
// setResponseDate returns a middleware that runs the rest of the handler
// chain and, only if that succeeds, overwrites the response Date header with
// date rendered in UTC using http.TimeFormat. An error from downstream is
// returned unchanged and leaves the Date header alone.
func setResponseDate(date time.Time) fiber.Handler {
	// The date never changes for this handler, so render it once up front.
	formatted := date.UTC().Format(http.TimeFormat)
	return func(c fiber.Ctx) error {
		err := c.Next()
		if err == nil {
			c.Response().Header.Set(fiber.HeaderDate, formatted)
		}
		return err
	}
}
// Test_CacheDateAndAgeHandling checks how the origin's Date and Age headers
// affect caching. setResponseDate pins the response Date to
// time.Now()+dateOffset. Each case then sends two GETs and checks the second
// response's X-Cache value. On an expected hit, its Age must be at least
// expectAgeAtLeast. Otherwise its body must show the handler's running
// count. In both cases the handler must have run exactly expectCount times.
func Test_CacheDateAndAgeHandling(t *testing.T) {
	t.Parallel()
	type testCase struct {
		name             string
		cacheControl     string
		cacheHeader      string        // expected X-Cache of the second request
		dateOffset       time.Duration // shift applied to time.Now() for the Date header
		expiration       time.Duration // Config.Expiration for the cache middleware
		expectAgeAtLeast int           // lower bound for Age when a hit is expected
		expectCount      int           // handler invocations expected after both requests
		originAge        int           // Age set by the handler when > 0
	}
	cases := []testCase{
		{
			name:             "age derived from past date without Age header",
			dateOffset:       -1 * time.Minute,
			cacheControl:     "public, max-age=120",
			cacheHeader:      cacheHit,
			expiration:       5 * time.Minute,
			expectAgeAtLeast: 1,
			expectCount:      1,
		},
		{
			name:         "stale due to past date despite max-age",
			dateOffset:   -90 * time.Second,
			cacheControl: "public, max-age=30",
			cacheHeader:  cacheUnreachable,
			expiration:   5 * time.Minute,
			expectCount:  2,
			originAge:    90,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			app.Use(New(Config{Expiration: tc.expiration}))
			app.Use(setResponseDate(time.Now().Add(tc.dateOffset).UTC()))
			var count int
			app.Get("/", func(c fiber.Ctx) error {
				count++
				if tc.originAge > 0 {
					c.Response().Header.Set(fiber.HeaderAge, strconv.Itoa(tc.originAge))
				}
				c.Set(fiber.HeaderCacheControl, tc.cacheControl)
				return c.SendString(strconv.Itoa(count))
			})
			_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			require.NoError(t, err)
			if tc.cacheHeader == cacheHit {
				// Let some wall-clock time pass before the expected hit.
				time.Sleep(2 * time.Second)
			}
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			require.NoError(t, err)
			require.Equal(t, tc.cacheHeader, resp.Header.Get("X-Cache"))
			if tc.cacheHeader == cacheHit {
				ageVal, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge))
				require.NoError(t, err)
				require.GreaterOrEqual(t, ageVal, tc.expectAgeAtLeast)
			} else {
				body, err := io.ReadAll(resp.Body)
				require.NoError(t, err)
				require.Equal(t, strconv.Itoa(tc.expectCount), string(body))
			}
			// Check the table's expectCount in both branches. The hit branch used
			// to hard-code 1, which silently ignored the field for hit cases.
			require.Equal(t, tc.expectCount, count)
		})
	}
}
func Test_CacheClampsInvalidStoredDate(t *testing.T) {
t.Parallel()
storage := newMutatingStorage(func(key string, val []byte) []byte {
if strings.HasSuffix(key, "_body") {
return val
}
var it item
if _, err := it.UnmarshalMsg(val); err != nil {
return val
}
it.date = uint64(math.MaxInt64) + 1024
updated, err := it.MarshalMsg(nil)
if err != nil {
return val
}
return updated
})
app := fiber.New()
app.Use(New(Config{
Expiration: time.Minute,
Storage: storage,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Set(fiber.HeaderCacheControl, "public, max-age=60")
return c.SendString("ok")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
parsedDate, err := http.ParseTime(resp.Header.Get(fiber.HeaderDate))
require.NoError(t, err)
require.WithinDuration(t, time.Now(), parsedDate, time.Minute)
ageVal, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge))
require.NoError(t, err)
require.Less(t, ageVal, 60)
require.GreaterOrEqual(t, ageVal, 0)
}
func Test_CacheClampsFutureStoredDate(t *testing.T) {
t.Parallel()
storage := newMutatingStorage(func(key string, val []byte) []byte {
if strings.HasSuffix(key, "_body") {
return val
}
var it item
if _, err := it.UnmarshalMsg(val); err != nil {
return val
}
future := time.Now().Add(2 * time.Second).UTC()
sec := future.Unix()
if sec < 0 {
sec = 0
}
it.date = uint64(sec) //nolint:gosec // safe: sec is clamped to non-negative range
updated, err := it.MarshalMsg(nil)
if err != nil {
return val
}
return updated
})
app := fiber.New()
app.Use(New(Config{
Expiration: time.Minute,
Storage: storage,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Set(fiber.HeaderCacheControl, "public, max-age=60")
return c.SendString("ok")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
parsedDate, err := http.ParseTime(resp.Header.Get(fiber.HeaderDate))
require.NoError(t, err)
require.False(t, parsedDate.After(time.Now()))
ageVal, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge))
require.NoError(t, err)
require.GreaterOrEqual(t, ageVal, 0)
}
// Test_RequestPragmaNoCacheTriggersMiss verifies that a request carrying
// "Pragma: no-cache" bypasses a fresh stored entry. The handler runs again,
// the response is reported as a miss, and later plain requests receive the
// new body from the cache.
func Test_RequestPragmaNoCacheTriggersMiss(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Expiration: time.Minute,
	}))
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("body" + strconv.Itoa(calls))
	})

	steps := []struct {
		pragmaNoCache bool
		wantCache     string
		wantBody      string
	}{
		{wantCache: cacheMiss, wantBody: "body1"},
		{wantCache: cacheHit, wantBody: "body1"},
		{pragmaNoCache: true, wantCache: cacheMiss, wantBody: "body2"},
		{wantCache: cacheHit, wantBody: "body2"},
	}
	for _, step := range steps {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		if step.pragmaNoCache {
			req.Header.Set(fiber.HeaderPragma, "no-cache")
		}
		res, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, step.wantCache, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, step.wantBody, string(payload))
	}
}
// Test_CacheStaleResponseAddsWarning110 verifies that when a stale entry is
// served to a request that allows it ("max-stale=5"), the response carries a
// Warning header containing code 110. After the 1s max-age has run out, it
// polls for up to 3s until a hit with Age >= 1 and a Warning header arrives.
func Test_CacheStaleResponseAddsWarning110(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Expiration: 2 * time.Second,
	}))
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "public, max-age=1")
		return c.SendString("body" + strconv.Itoa(calls))
	})

	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, res.Header.Get("X-Cache"))

	staleOK := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	staleOK.Header.Set(fiber.HeaderCacheControl, "max-stale=5")

	// servedStaleWithWarning reports whether r is a cache hit whose Age is at
	// least one second and which carries at least one Warning value.
	servedStaleWithWarning := func(r *http.Response) bool {
		if r.Header.Get("X-Cache") != cacheHit {
			return false
		}
		age, convErr := strconv.Atoi(r.Header.Get(fiber.HeaderAge))
		require.NoError(t, convErr)
		return age >= 1 && len(r.Header.Values(fiber.HeaderWarning)) > 0
	}

	// Wait for max-age=1 to elapse (with some slack), then poll until the
	// stale entry is served.
	time.Sleep(1200 * time.Millisecond)
	deadline := time.Now().Add(3 * time.Second)
	for {
		res, err = app.Test(staleOK)
		require.NoError(t, err)
		if servedStaleWithWarning(res) {
			break
		}
		require.True(t, time.Now().Before(deadline), "response did not become stale before deadline")
		time.Sleep(50 * time.Millisecond)
	}

	warnings := res.Header.Values(fiber.HeaderWarning)
	require.NotEmpty(t, warnings, "Warning header should be present when serving stale response")
	has110 := false
	for _, w := range warnings {
		has110 = has110 || strings.Contains(w, "110")
	}
	require.True(t, has110, "warning 110 not found in %v", warnings)
}
// Test_CacheHeuristicFreshnessAddsWarning113 covers the negative case for
// Warning 113. A response with an explicit "max-age=60" that is served from
// the cache must carry no Warning value mentioning 113.
// NOTE(review): despite its name, this test asserts the warning is ABSENT.
// The positive case is Test_CacheHeuristicFreshnessAddsWarning113AfterThreshold.
func Test_CacheHeuristicFreshnessAddsWarning113(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Expiration: 2 * time.Second,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("body")
	})

	var res *http.Response
	for _, wantCache := range []string{cacheMiss, cacheHit} {
		var err error
		res, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, wantCache, res.Header.Get("X-Cache"))
	}

	for _, warning := range res.Header.Values(fiber.HeaderWarning) {
		require.NotContains(t, warning, "113", "warning 113 should not be present for explicitly fresh responses")
	}
}
func Test_CacheHeuristicFreshnessAddsWarning113AfterThreshold(t *testing.T) {
t.Parallel()
storage := newMutatingStorage(func(key string, val []byte) []byte {
if strings.HasSuffix(key, "_body") {
return val
}
var it item
if _, err := it.UnmarshalMsg(val); err != nil {
return val
}
oldDate := time.Now().Add(-25 * time.Hour).UTC()
sec := oldDate.Unix()
if sec < 0 {
sec = 0
}
it.date = uint64(sec) //nolint:gosec // safe: sec is clamped to non-negative range
future := time.Now().Add(48 * time.Hour).UTC()
expSec := future.Unix()
if expSec < 0 {
expSec = 0
}
it.exp = uint64(expSec) //nolint:gosec // safe: expSec is clamped to non-negative range
it.ttl = uint64((48 * time.Hour) / time.Second)
updated, err := it.MarshalMsg(nil)
if err != nil {
return val
}
return updated
})
app := fiber.New()
app.Use(New(Config{
Expiration: 2 * time.Second,
Storage: storage,
}))
var count int
app.Get("/", func(c fiber.Ctx) error {
count++
return c.SendString("body" + strconv.Itoa(count))
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
warnings := resp.Header.Values(fiber.HeaderWarning)
require.NotEmpty(t, warnings)
found := false
for _, w := range warnings {
if strings.Contains(w, "113") {
found = true
break
}
}
require.True(t, found, "warning 113 not found in %v", warnings)
}
func Test_CacheAgeHeaderIsCappedAtMaxDeltaSeconds(t *testing.T) {
t.Parallel()
const veryLargeAge = uint64(math.MaxInt32) + 1000
storage := newMutatingStorage(func(key string, val []byte) []byte {
if strings.HasSuffix(key, "_body") {
return val
}
var it item
if _, err := it.UnmarshalMsg(val); err != nil {
return val
}
it.age = veryLargeAge
updated, err := it.MarshalMsg(nil)
if err != nil {
return val
}
return updated
})
app := fiber.New()
app.Use(New(Config{
Expiration: time.Minute,
Storage: storage,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Set(fiber.HeaderCacheControl, "public, max-age=60")
return c.SendString("body")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
ageVal, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge))
require.NoError(t, err)
require.Equal(t, math.MaxInt32, ageVal)
}
// Test_CacheMinFreshForcesRevalidation verifies that a request with
// "min-fresh=10" is not served an entry whose max-age is only 5s. The
// handler runs again, and a later plain request receives the new body as a
// cache hit.
func Test_CacheMinFreshForcesRevalidation(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "public, max-age=5")
		return c.SendString(strconv.Itoa(calls))
	})

	// readBody drains r's body, failing the test on a read error.
	readBody := func(r *http.Response) string {
		raw, err := io.ReadAll(r.Body)
		require.NoError(t, err)
		return string(raw)
	}

	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, res.Header.Get("X-Cache"))
	require.Equal(t, "1", readBody(res))

	const minFresh = "min-fresh=10"
	demanding := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	demanding.Header.Set(fiber.HeaderCacheControl, minFresh)
	res, err = app.Test(demanding)
	require.NoError(t, err)
	require.Equalf(t, cacheMiss, res.Header.Get("X-Cache"), "dirs=%+v Age=%s count=%d",
		parseRequestCacheControlString(minFresh), res.Header.Get(fiber.HeaderAge), calls)
	require.Equal(t, "2", readBody(res))

	res, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, res.Header.Get("X-Cache"))
	require.Equal(t, "2", readBody(res))
}
// Test_CachePermanentRedirectCached verifies that a 308 Permanent Redirect is
// cacheable when StoreResponseHeaders is enabled. The hit replays the status
// code, the original body and the Location header.
func Test_CachePermanentRedirectCached(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Expiration:           30 * time.Second,
		StoreResponseHeaders: true,
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|status-308"
		},
	}))
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "public, max-age=30")
		c.Set(fiber.HeaderLocation, "/dest")
		return c.Status(fiber.StatusPermanentRedirect).SendString("redir" + strconv.Itoa(calls))
	})

	// Both the miss and the hit must look identical apart from X-Cache.
	for _, wantCache := range []string{cacheMiss, cacheHit} {
		res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, wantCache, res.Header.Get("X-Cache"))
		require.Equal(t, fiber.StatusPermanentRedirect, res.StatusCode)
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, "redir1", string(payload))
		require.Equal(t, "/dest", res.Header.Get(fiber.HeaderLocation))
	}
}
// Test_CacheNoStoreDirective verifies that a response marked "no-store" is
// never cached. Two consecutive requests both report X-Cache=unreachable.
func Test_CacheNoStoreDirective(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "no-store")
		return c.SendString("ok")
	})

	for attempt := 0; attempt < 2; attempt++ {
		res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, res.Header.Get("X-Cache"))
	}
}
// Test_CacheNoCacheDirective verifies that a response marked "no-cache" is
// never answered from the cache. Each request reaches the handler (the
// counter in the body advances) and reports X-Cache=unreachable.
func Test_CacheNoCacheDirective(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "no-cache")
		return c.SendString(strconv.Itoa(calls))
	})

	for _, wantBody := range []string{"1", "2"} {
		res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(payload))
	}
}
// Test_CacheNoCacheDirectiveOverridesExistingEntry verifies the following
// sequence. A cacheable response is stored first. The handler then switches
// to "no-cache" and a request with "Cache-Control: no-cache" reaches it.
// After that, even a plain request is no longer served the old cacheable
// body.
func Test_CacheNoCacheDirectiveOverridesExistingEntry(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	var serveNoCache atomic.Bool
	app.Get("/", func(c fiber.Ctx) error {
		if !serveNoCache.Load() {
			c.Set(fiber.HeaderCacheControl, "public, max-age=60")
			return c.SendString("cacheable")
		}
		c.Set(fiber.HeaderCacheControl, "no-cache")
		return c.SendString("no-cache")
	})

	// check sends req and asserts its X-Cache value and body.
	check := func(req *http.Request, wantCache, wantBody string) {
		res, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, wantCache, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(payload))
	}
	plain := func() *http.Request {
		return httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	}

	check(plain(), cacheMiss, "cacheable")

	serveNoCache.Store(true)
	bypass := plain()
	bypass.Header.Set(fiber.HeaderCacheControl, "no-cache")
	check(bypass, cacheUnreachable, "no-cache")

	check(plain(), cacheUnreachable, "no-cache")
}
// Test_CacheRespectsUpstreamAgeForFreshness verifies that an upstream Age
// header counts against the response's "max-age=2". When Age equals max-age,
// the response is never stored. When Age=1, the entry is served from the
// cache at first and then refetched once the remaining lifetime has passed.
func Test_CacheRespectsUpstreamAgeForFreshness(t *testing.T) {
	t.Parallel()

	// newApp puts the cache (keyed by path plus keySuffix) in front of a
	// handler. The handler sends "public, max-age=2" with the given upstream
	// Age and replies with its running call count.
	newApp := func(keySuffix, upstreamAge string) *fiber.App {
		app := fiber.New()
		app.Use(New(Config{
			KeyGenerator: func(c fiber.Ctx) string {
				return c.Path() + keySuffix
			},
		}))
		var calls int
		app.Get("/", func(c fiber.Ctx) error {
			calls++
			c.Set(fiber.HeaderCacheControl, "public, max-age=2")
			c.Set(fiber.HeaderAge, upstreamAge)
			return c.SendString(strconv.Itoa(calls))
		})
		return app
	}

	// expect issues a plain GET against app and checks X-Cache and the body.
	expect := func(t *testing.T, app *fiber.App, wantCache, wantBody string) {
		t.Helper()
		res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, wantCache, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(payload))
	}

	t.Run("skipsCachingWhenAgeExhaustsFreshness", func(t *testing.T) {
		t.Parallel()
		app := newApp("|age-exhausted", "2")
		expect(t, app, cacheUnreachable, "1")
		expect(t, app, cacheUnreachable, "2")
	})

	t.Run("expiresAfterRemainingLifetime", func(t *testing.T) {
		t.Parallel()
		app := newApp("|age-remaining", "1")
		expect(t, app, cacheMiss, "1")
		expect(t, app, cacheHit, "1")
		time.Sleep(1500 * time.Millisecond)
		expect(t, app, cacheMiss, "2")
	})
}
// Test_CacheVarySeparatesVariants verifies that a response with
// "Vary: Accept-Language" is stored per language even though the
// KeyGenerator returns the same key for every request. "en" and "fr" each
// start with a miss and are later served their own body as hits.
func Test_CacheVarySeparatesVariants(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|vary-separated"
		},
	}))
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderVary, fiber.HeaderAcceptLanguage)
		return c.SendString(c.Get(fiber.HeaderAcceptLanguage) + strconv.Itoa(calls))
	})

	steps := []struct {
		lang      string
		wantCache string
		wantBody  string
	}{
		{"en", cacheMiss, "en1"},
		{"fr", cacheMiss, "fr2"},
		{"en", cacheHit, "en1"},
		{"fr", cacheHit, "fr2"},
	}
	for _, step := range steps {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAcceptLanguage, step.lang)
		res, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, step.wantCache, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, step.wantBody, string(payload))
	}
}
// Test_CacheVaryStarUncacheable verifies that a response with "Vary: *" is
// never stored. Every request reaches the handler and reports
// X-Cache=unreachable.
func Test_CacheVaryStarUncacheable(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|vary-star"
		},
	}))
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderVary, "*")
		return c.SendString(strconv.Itoa(calls))
	})

	for _, wantBody := range []string{"1", "2"} {
		res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(payload))
	}
}
// Test_CachePrivateDirective verifies that a response marked "private" is
// not stored by the middleware. Each request reaches the handler and reports
// X-Cache=unreachable.
func Test_CachePrivateDirective(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "private")
		return c.SendString(strconv.Itoa(calls))
	})

	for _, wantBody := range []string{"1", "2"} {
		res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(payload))
	}
}
// Test_CachePrivateDirectiveWithAuthorization verifies that a "private"
// response is not stored for requests carrying an Authorization header.
// Each such request reaches the handler and reports X-Cache=unreachable.
func Test_CachePrivateDirectiveWithAuthorization(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "private")
		return c.SendString(strconv.Itoa(calls))
	})

	for _, wantBody := range []string{"1", "2"} {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
		res, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(payload))
	}
}
// Test_CachePrivateDirectiveInvalidatesExistingEntry verifies the following
// sequence. A public response is stored first. A later "no-cache" request is
// answered with a "private" response (X-Cache=unreachable). After the handler
// goes back to public, the next plain request is a miss, which shows the
// earlier stored entry is gone.
func Test_CachePrivateDirectiveInvalidatesExistingEntry(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	var servePrivate atomic.Bool
	app.Get("/", func(c fiber.Ctx) error {
		if !servePrivate.Load() {
			c.Set(fiber.HeaderCacheControl, "public, max-age=60")
			return c.SendString("public")
		}
		c.Set(fiber.HeaderCacheControl, "private")
		return c.SendString("private")
	})

	// check sends req and asserts its X-Cache value and body.
	check := func(req *http.Request, wantCache, wantBody string) {
		res, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, wantCache, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(payload))
	}
	plain := func() *http.Request {
		return httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	}

	check(plain(), cacheMiss, "public")

	servePrivate.Store(true)
	bypass := plain()
	bypass.Header.Set(fiber.HeaderCacheControl, "no-cache")
	check(bypass, cacheUnreachable, "private")

	servePrivate.Store(false)
	check(plain(), cacheMiss, "public")
}
// Test_CacheControlNotOverwritten verifies that, with StoreResponseHeaders
// enabled, a handler's "Cache-Control: private" reaches the client unchanged
// on a repeated request.
func Test_CacheControlNotOverwritten(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second, StoreResponseHeaders: true}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "private")
		return c.SendString("ok")
	})

	var last *http.Response
	for attempt := 0; attempt < 2; attempt++ {
		res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		last = res
	}
	require.Equal(t, "private", last.Header.Get(fiber.HeaderCacheControl))
}
// Test_CacheMaxAgeDirective verifies that a response's "max-age=1" limits
// how long the entry lives, even though Config.Expiration is 10s. A request
// made 1.5s after the priming request is a miss.
func Test_CacheMaxAgeDirective(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "max-age=1")
		return c.SendString("1")
	})

	get := func() *http.Response {
		res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		return res
	}

	_ = get() // prime the cache
	time.Sleep(1500 * time.Millisecond)
	require.Equal(t, cacheMiss, get().Header.Get("X-Cache"))
}
// Test_ParseMaxAge table-tests parseMaxAge with these expectations:
//   - max-age is found case-insensitively among other directives.
//   - s-maxage is not mistaken for max-age.
//   - surrounding whitespace around commas is tolerated.
//   - a missing, valueless or non-numeric max-age reports ok=false.
func Test_ParseMaxAge(t *testing.T) {
	t.Parallel()
	tests := []struct {
		header string
		expect time.Duration
		ok     bool
	}{
		{"max-age=60", 60 * time.Second, true},
		{"public, max-age=86400", 86400 * time.Second, true},
		{"no-store", 0, false},
		{"max-age=invalid", 0, false},
		{"public, s-maxage=100, max-age=50", 50 * time.Second, true},
		{"MAX-AGE=20", 20 * time.Second, true},
		{"public , max-age=0", 0, true},
		{"public , max-age", 0, false},
	}
	for _, tt := range tests {
		t.Run(tt.header, func(t *testing.T) {
			t.Parallel()
			d, ok := parseMaxAge(tt.header)
			// Use require like the rest of this file. The header goes into each
			// message so a failure names its input.
			require.Equal(t, tt.ok, ok, "ok mismatch for header %q", tt.header)
			if ok {
				require.Equal(t, tt.expect, d, "duration mismatch for header %q", tt.header)
			}
		})
	}
}
func Test_AllowsSharedCache(t *testing.T) {
t.Parallel()
tests := []struct {
directives string
expect bool
}{
{"public", true},
{"private", false},
{"s-maxage=60", true},
{"public, max-age=60", true},
{"public, must-revalidate", true},
{"max-age=60", false},
{"no-cache", false},
{"no-cache, s-maxage=60", true},
{"", false},
}
for _, tt := range tests {
t.Run(tt.directives, func(t *testing.T) {
t.Parallel()
got := allowsSharedCache(tt.directives)
require.Equal(t, tt.expect, got, "directives: %q", tt.directives)
})
}
t.Run("private overrules public", func(t *testing.T) {
t.Parallel()
got := allowsSharedCache(strings.ToUpper("private, public"))
require.False(t, got)
})
}
// TestCacheSkipsAuthorizationByDefault verifies that, with the default
// config and no Cache-Control from the handler, requests carrying an
// Authorization header are never cached. Each one reaches the handler and
// reports X-Cache=unreachable.
func TestCacheSkipsAuthorizationByDefault(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		return c.SendString(strconv.Itoa(calls))
	})

	for _, wantBody := range []string{"1", "2"} {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
		res, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, wantBody, string(payload))
	}
}
// TestCacheBypassesExistingEntryForAuthorization verifies that an
// Authorization request is not answered from an entry stored for an
// anonymous request: it reaches the handler and reports X-Cache=unreachable.
// It also verifies that the Authorization request does not replace that
// entry, because the anonymous request is afterwards still a hit with the
// original body.
func TestCacheBypassesExistingEntryForAuthorization(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		return c.SendString(strconv.Itoa(calls))
	})

	anonymous := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	withToken := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	withToken.Header.Set(fiber.HeaderAuthorization, "Bearer token")

	steps := []struct {
		req       *http.Request
		wantCache string
		wantBody  string
	}{
		{anonymous, cacheMiss, "1"},
		{withToken, cacheUnreachable, "2"},
		{anonymous, cacheHit, "1"},
	}
	for _, step := range steps {
		res, err := app.Test(step.req)
		require.NoError(t, err)
		require.Equal(t, step.wantCache, res.Header.Get("X-Cache"))
		payload, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		require.Equal(t, step.wantBody, string(payload))
	}
}
// TestCacheAllowsSharedCacheWithAuthorization verifies that a response marked
// "public, max-age=60" is cached even for requests with an Authorization
// header. The second authorized request is a hit, and the handler ran only
// once.
func TestCacheAllowsSharedCacheWithAuthorization(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second}))
	var calls int
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("ok")
	})

	// withToken builds a GET carrying the bearer token.
	withToken := func() *http.Request {
		r := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		r.Header.Set(fiber.HeaderAuthorization, "Bearer token")
		return r
	}

	first, err := app.Test(withToken())
	require.NoError(t, err)
	require.Equal(t, cacheMiss, first.Header.Get("X-Cache"))

	second, err := app.Test(withToken())
	require.NoError(t, err)
	require.Equal(t, cacheHit, second.Header.Get("X-Cache"))

	payload, err := io.ReadAll(second.Body)
	require.NoError(t, err)
	require.Equal(t, "ok", string(payload))
	require.Equal(t, 1, calls)
}
// TestCacheAllowsAuthorizationWithRevalidateDirectives covers which upstream
// freshness signals make an Authorization-bearing response storable.
// The must-revalidate and proxy-revalidate directives (with max-age) allow it.
// An Expires header with an empty Cache-Control does not.
func TestCacheAllowsAuthorizationWithRevalidateDirectives(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name          string
		cacheControl  string
		expires       string
		expectedBody  string
		expectedBody2 string
		expectFirst   string
		expectSecond  string
	}{
		{
			name:          "must-revalidate",
			cacheControl:  "must-revalidate, max-age=60",
			expectedBody:  "ok-1",
			expectedBody2: "ok-1",
			expectFirst:   cacheMiss,
			expectSecond:  cacheHit,
		},
		{
			name:          "proxy-revalidate",
			cacheControl:  "proxy-revalidate, max-age=60",
			expectedBody:  "ok-1",
			expectedBody2: "ok-1",
			expectFirst:   cacheMiss,
			expectSecond:  cacheHit,
		},
		{
			name:          "expires header",
			cacheControl:  "",
			expires:       time.Now().Add(1 * time.Minute).UTC().Format(http.TimeFormat),
			expectedBody:  "ok-1",
			expectedBody2: "ok-2",
			expectFirst:   cacheUnreachable,
			expectSecond:  cacheUnreachable,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			calls := 0
			app := fiber.New()
			app.Use(New(Config{Expiration: 10 * time.Second}))
			app.Get("/", func(c fiber.Ctx) error {
				calls++
				// Cache-Control is always set, even to an empty value.
				c.Set(fiber.HeaderCacheControl, tc.cacheControl)
				if tc.expires != "" {
					c.Set(fiber.HeaderExpires, tc.expires)
				}
				return c.SendString("ok-" + strconv.Itoa(calls))
			})

			req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
			req.Header.Set(fiber.HeaderAuthorization, "Bearer token")

			// roundTrip replays req and checks the X-Cache status and body.
			roundTrip := func(wantCache, wantBody string) {
				t.Helper()
				resp, err := app.Test(req)
				require.NoError(t, err)
				require.Equal(t, wantCache, resp.Header.Get("X-Cache"))
				body, err := io.ReadAll(resp.Body)
				require.NoError(t, err)
				require.Equal(t, wantBody, string(body))
			}

			roundTrip(tc.expectFirst, tc.expectedBody)
			roundTrip(tc.expectSecond, tc.expectedBody2)

			// A cache hit on the second request means the handler ran only once.
			wantCalls := 2
			if tc.expectSecond == cacheHit {
				wantCalls = 1
			}
			require.Equal(t, wantCalls, calls)
		})
	}
}
// TestCacheSeparatesAuthorizationValues ensures that cached "public" responses
// are partitioned by Authorization value, so one credential never receives a
// body that was generated for another.
func TestCacheSeparatesAuthorizationValues(t *testing.T) {
	t.Parallel()

	calls := 0
	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second}))
	app.Get("/", func(c fiber.Ctx) error {
		calls++
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString(fmt.Sprintf("body-%d-%s", calls, c.Get(fiber.HeaderAuthorization)))
	})

	const (
		tokenA = "token-a"
		tokenB = "token-b"
	)

	// Each step issues one request. wantCalls == 0 means the handler counter
	// is not checked after that step.
	steps := []struct {
		token     string
		wantCache string
		wantBody  string
		wantCalls int
	}{
		{token: tokenA, wantCache: cacheMiss, wantBody: "body-1-Bearer " + tokenA},
		{token: tokenA, wantCache: cacheHit, wantBody: "body-1-Bearer " + tokenA, wantCalls: 1},
		{token: tokenB, wantCache: cacheMiss, wantBody: "body-2-Bearer " + tokenB, wantCalls: 2},
		{token: tokenB, wantCache: cacheHit, wantBody: "body-2-Bearer " + tokenB},
		{token: tokenA, wantCache: cacheHit, wantBody: "body-1-Bearer " + tokenA, wantCalls: 2},
	}

	for _, step := range steps {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, "Bearer "+step.token)

		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, step.wantCache, resp.Header.Get("X-Cache"))

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, step.wantBody, string(body))

		if step.wantCalls != 0 {
			require.Equal(t, step.wantCalls, calls)
		}
	}
}
// Benchmark_Cache repeatedly requests the same /demo route through the cache
// middleware with its default configuration.
//
// go test -v -run=^$ -bench=Benchmark_Cache -benchmem -count=4
func Benchmark_Cache(b *testing.B) {
	app := fiber.New()
	app.Use(New())
	app.Get("/demo", func(c fiber.Ctx) error {
		data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
		return c.Status(fiber.StatusTeapot).Send(data)
	})

	handler := app.Handler()
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod(fiber.MethodGet)
	reqCtx.Request.SetRequestURI("/demo")

	b.ReportAllocs()
	for b.Loop() {
		handler(reqCtx)
	}

	resp := &reqCtx.Response
	require.Equal(b, fiber.StatusTeapot, resp.Header.StatusCode())
	require.Greater(b, len(resp.Body()), 30000)
}
func Benchmark_Cache_Miss(b *testing.B) {
app := fiber.New()
app.Use(New())
app.Get("/*", func(c fiber.Ctx) error {
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
return c.Status(fiber.StatusOK).Send(data)
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
b.ReportAllocs()
b.ResetTimer()
var n int
for b.Loop() {
n++
fctx.Request.SetRequestURI("/demo/" + strconv.Itoa(n))
h(fctx)
}
require.Equal(b, fiber.StatusOK, fctx.Response.Header.StatusCode())
require.Greater(b, len(fctx.Response.Body()), 30000)
}
// go test -v -run=^$ -bench=Benchmark_Cache_Storage -benchmem -count=4
func Benchmark_Cache_Storage(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Storage: memory.New(),
}))
app.Get("/demo", func(c fiber.Ctx) error {
data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark
return c.Status(fiber.StatusTeapot).Send(data)
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/demo")
b.ReportAllocs()
for b.Loop() {
h(fctx)
}
require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
require.Greater(b, len(fctx.Response.Body()), 30000)
}
// Benchmark_Cache_AdditionalHeaders runs the cache with StoreResponseHeaders
// enabled. The handler adds a custom header, and the benchmark checks that the
// header is present on the final response.
func Benchmark_Cache_AdditionalHeaders(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{StoreResponseHeaders: true}))
	app.Get("/demo", func(c fiber.Ctx) error {
		c.Response().Header.Add("X-Foobar", "foobar")
		return c.SendStatus(fiber.StatusTeapot)
	})

	serve := app.Handler()
	fctx := &fasthttp.RequestCtx{}
	req := &fctx.Request
	req.Header.SetMethod(fiber.MethodGet)
	req.SetRequestURI("/demo")

	b.ReportAllocs()
	for b.Loop() {
		serve(fctx)
	}

	require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
	require.Equal(b, []byte("foobar"), fctx.Response.Header.Peek("X-Foobar"))
}
// Benchmark_Cache_MaxSize measures the overhead of MaxBytes tracking. Each
// iteration requests a fresh path, so no request can be answered by an earlier
// entry. The benchmark runs with three MaxBytes settings:
//  1. 0: tracking is disabled, so there is no accounting overhead.
//  2. math.MaxUint32: large enough to store all entries, so nothing is removed.
//  3. 100: small enough that insertions constantly trigger removals.
func Benchmark_Cache_MaxSize(b *testing.B) {
	cases := []struct {
		name     string
		maxBytes uint
	}{
		{name: "Disabled", maxBytes: 0},
		{name: "Unlim", maxBytes: math.MaxUint32},
		{name: "LowBounded", maxBytes: 100},
	}

	for _, tc := range cases {
		b.Run(tc.name, func(b *testing.B) {
			app := fiber.New()
			app.Use(New(Config{MaxBytes: tc.maxBytes}))
			app.Get("/*", func(c fiber.Ctx) error {
				return c.Status(fiber.StatusTeapot).SendString("1")
			})

			h := app.Handler()
			fctx := &fasthttp.RequestCtx{}
			fctx.Request.Header.SetMethod(fiber.MethodGet)

			b.ReportAllocs()
			n := 0
			for b.Loop() {
				n++
				// strconv avoids fmt's boxing/reflection allocations, which would
				// otherwise be attributed to the cache in -benchmem output.
				fctx.Request.SetRequestURI("/" + strconv.Itoa(n))
				h(fctx)
			}

			require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
		})
	}
}
// Test_Cache_RevalidationWithMaxBytes covers revalidation forced by the request
// (Cache-Control: max-age=0 or min-fresh) on a cache with MaxBytes tracking
// enabled. It checks which entry is served afterwards and that unrelated
// tracked entries are unaffected.
func Test_Cache_RevalidationWithMaxBytes(t *testing.T) {
	t.Parallel()
	// A forced revalidation whose new response is cacheable must replace the
	// previously stored entry.
	t.Run("max-age=0 revalidation removes old entry on storage success", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			MaxBytes: 100,
		}))
		// Counts handler invocations; encoded in the body so tests can tell responses apart.
		requestCount := 0
		app.Get("/test", func(c fiber.Ctx) error {
			requestCount++
			c.Set(fiber.HeaderCacheControl, "max-age=60")
			return c.SendString(fmt.Sprintf("response-%d", requestCount))
		})
		// First request - cache the response
		req1 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		resp1, err := app.Test(req1)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, resp1.Header.Get("X-Cache"))
		// Request with max-age=0 to force revalidation
		req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req2.Header.Set(fiber.HeaderCacheControl, "max-age=0")
		resp2, err := app.Test(req2)
		require.NoError(t, err)
		body2, err := io.ReadAll(resp2.Body)
		require.NoError(t, err)
		require.Equal(t, "response-2", string(body2))
		// Next request should serve the NEW cached entry
		req3 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		resp3, err := app.Test(req3)
		require.NoError(t, err)
		require.Equal(t, cacheHit, resp3.Header.Get("X-Cache"))
		body3, err := io.ReadAll(resp3.Body)
		require.NoError(t, err)
		require.Equal(t, "response-2", string(body3), "New entry should be cached")
	})
	// A request whose min-fresh exceeds the entry's remaining freshness must
	// bypass the stored entry, and the fresh response replaces it.
	t.Run("min-fresh revalidation with MaxBytes", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			MaxBytes: 100,
		}))
		requestCount := 0
		app.Get("/test", func(c fiber.Ctx) error {
			requestCount++
			c.Set(fiber.HeaderCacheControl, "max-age=2")
			return c.SendString(fmt.Sprintf("response-%d", requestCount))
		})
		// First request - cache the response
		req1 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		resp1, err := app.Test(req1)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, resp1.Header.Get("X-Cache"))
		// Wait a bit so the entry has aged
		// NOTE(review): min-fresh=5 below already exceeds the response's max-age=2,
		// so revalidation is expected even without this sleep; it only ages the entry.
		time.Sleep(1 * time.Second)
		// Request with min-fresh that exceeds remaining freshness
		req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req2.Header.Set(fiber.HeaderCacheControl, "min-fresh=5")
		resp2, err := app.Test(req2)
		require.NoError(t, err)
		body2, err := io.ReadAll(resp2.Body)
		require.NoError(t, err)
		require.Equal(t, "response-2", string(body2))
		// Next request should serve the NEW cached entry
		req3 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		resp3, err := app.Test(req3)
		require.NoError(t, err)
		require.Equal(t, cacheHit, resp3.Header.Get("X-Cache"))
		body3, err := io.ReadAll(resp3.Body)
		require.NoError(t, err)
		require.Equal(t, "response-2", string(body3))
	})
	// Revalidating one entry in a full cache must keep size accounting correct.
	// The other entry (/b) must remain cached afterwards.
	t.Run("revalidation respects MaxBytes eviction", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			MaxBytes:            20, // Only room for 2 responses of 10 bytes each
			ExpirationGenerator: stableAscendingExpiration(),
		}))
		app.Get("/*", func(c fiber.Ctx) error {
			c.Set(fiber.HeaderCacheControl, "max-age=60")
			return c.SendString("1234567890") // 10 bytes
		})
		// Cache /a and /b
		req1 := httptest.NewRequest(fiber.MethodGet, "/a", http.NoBody)
		resp1, err := app.Test(req1)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, resp1.Header.Get("X-Cache"))
		req2 := httptest.NewRequest(fiber.MethodGet, "/b", http.NoBody)
		resp2, err := app.Test(req2)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, resp2.Header.Get("X-Cache"))
		// Both should be cached
		req3 := httptest.NewRequest(fiber.MethodGet, "/a", http.NoBody)
		resp3, err := app.Test(req3)
		require.NoError(t, err)
		require.Equal(t, cacheHit, resp3.Header.Get("X-Cache"))
		req4 := httptest.NewRequest(fiber.MethodGet, "/b", http.NoBody)
		resp4, err := app.Test(req4)
		require.NoError(t, err)
		require.Equal(t, cacheHit, resp4.Header.Get("X-Cache"))
		// Revalidate /a with max-age=0
		req5 := httptest.NewRequest(fiber.MethodGet, "/a", http.NoBody)
		req5.Header.Set(fiber.HeaderCacheControl, "max-age=0")
		_, err = app.Test(req5)
		require.NoError(t, err)
		// /a should be revalidated and cached again
		req6 := httptest.NewRequest(fiber.MethodGet, "/a", http.NoBody)
		resp6, err := app.Test(req6)
		require.NoError(t, err)
		require.Equal(t, cacheHit, resp6.Header.Get("X-Cache"))
		// /b should still be cached (heap accounting should be correct)
		req7 := httptest.NewRequest(fiber.MethodGet, "/b", http.NoBody)
		resp7, err := app.Test(req7)
		require.NoError(t, err)
		require.Equal(t, cacheHit, resp7.Header.Get("X-Cache"))
	})
	// If the revalidated response cannot be stored (no-store), the previously
	// cached entry must remain and continue to be served.
	t.Run("revalidation with non-cacheable response preserves old entry", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			MaxBytes: 100,
		}))
		requestCount := 0
		app.Get("/test", func(c fiber.Ctx) error {
			requestCount++
			if requestCount == 1 {
				c.Set(fiber.HeaderCacheControl, "max-age=60")
				return c.SendString("cacheable")
			}
			// Second request returns no-store
			c.Set(fiber.HeaderCacheControl, "no-store")
			return c.SendString("not-cacheable")
		})
		// First request - cache the response
		req1 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		resp1, err := app.Test(req1)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, resp1.Header.Get("X-Cache"))
		body1, err := io.ReadAll(resp1.Body)
		require.NoError(t, err)
		require.Equal(t, "cacheable", string(body1))
		// Request with max-age=0 to force revalidation
		// The new response will be no-store (not cacheable)
		req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req2.Header.Set(fiber.HeaderCacheControl, "max-age=0")
		resp2, err := app.Test(req2)
		require.NoError(t, err)
		body2, err := io.ReadAll(resp2.Body)
		require.NoError(t, err)
		require.Equal(t, "not-cacheable", string(body2))
		// Next request should still serve the OLD cached entry
		// because the new response was not cacheable and old entry should remain tracked
		req3 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		resp3, err := app.Test(req3)
		require.NoError(t, err)
		require.Equal(t, cacheHit, resp3.Header.Get("X-Cache"))
		body3, err := io.ReadAll(resp3.Body)
		require.NoError(t, err)
		require.Equal(t, "cacheable", string(body3), "Old entry should still be cached")
	})
}
// Test_parseCacheControlDirectives_QuotedStrings checks how
// parseCacheControlDirectives handles directive values written as
// quoted-strings (RFC 9110 Section 5.6.4), which the Cache-Control grammar of
// RFC 9111 Section 5.2 permits. The expected results show that:
//   - surrounding quotes are removed;
//   - \" and \\ escapes are decoded;
//   - whitespace around "=" and "," is ignored;
//   - valueless directives are reported with an empty value.
func Test_parseCacheControlDirectives_QuotedStrings(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		expected map[string]string
		input    string
	}{
		{
			name:  "simple quoted value",
			input: `community="UCI"`,
			expected: map[string]string{
				"community": "UCI",
			},
		},
		{
			name:  "multiple directives with quoted values",
			input: `max-age=3600, community="UCI", custom="value"`,
			expected: map[string]string{
				"max-age":   "3600",
				"community": "UCI",
				"custom":    "value",
			},
		},
		{
			name:  "quoted value with spaces",
			input: `custom="value with spaces"`,
			expected: map[string]string{
				"custom": "value with spaces",
			},
		},
		{
			name:  "quoted value with escaped quote",
			input: `custom="value with \"quotes\""`,
			expected: map[string]string{
				"custom": `value with "quotes"`,
			},
		},
		{
			name:  "quoted value with escaped backslash",
			input: `custom="value with \\ backslash"`,
			expected: map[string]string{
				"custom": `value with \ backslash`,
			},
		},
		{
			name:  "mixed quoted and unquoted values",
			input: `max-age=3600, community="UCI", no-cache, custom="test"`,
			expected: map[string]string{
				"max-age":   "3600",
				"community": "UCI",
				"no-cache":  "",
				"custom":    "test",
			},
		},
		{
			name:  "quoted empty value",
			input: `custom=""`,
			expected: map[string]string{
				"custom": "",
			},
		},
		{
			name:  "spaces around quoted value",
			input: `custom = "value" , another="test"`,
			expected: map[string]string{
				"custom":  "value",
				"another": "test",
			},
		},
		{
			name:  "unquoted token value",
			input: `max-age=3600`,
			expected: map[string]string{
				"max-age": "3600",
			},
		},
		{
			name:  "complex mixed case",
			input: `max-age=3600, s-maxage=7200, community="UCI", no-store, custom="value with \"escaped\" quotes"`,
			expected: map[string]string{
				"max-age":   "3600",
				"s-maxage":  "7200",
				"community": "UCI",
				"no-store":  "",
				"custom":    `value with "escaped" quotes`,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// Collect every (key, value) pair reported by the callback. Keys and
			// values are copied via string() conversion.
			result := make(map[string]string)
			parseCacheControlDirectives([]byte(tt.input), func(key, value []byte) {
				result[string(key)] = string(value)
			})
			require.Equal(t, tt.expected, result)
		})
	}
}
// Test_unquoteCacheDirective checks how unquoteCacheDirective handles
// quoted-string values. Per the expected results, it strips the surrounding
// double quotes and decodes \" and \\ escapes. Inputs too short to be a quoted
// string (a lone quote, or nothing at all) come back unchanged.
func Test_unquoteCacheDirective(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		input    []byte
		expected []byte
	}{
		{
			name:     "simple quoted string",
			input:    []byte(`"value"`),
			expected: []byte("value"),
		},
		{
			name:     "empty quoted string",
			input:    []byte(`""`),
			expected: []byte(""),
		},
		{
			name:     "quoted string with spaces",
			input:    []byte(`"value with spaces"`),
			expected: []byte("value with spaces"),
		},
		{
			name:     "quoted string with escaped quote",
			input:    []byte(`"value with \"quote\""`),
			expected: []byte(`value with "quote"`),
		},
		{
			name:     "quoted string with escaped backslash",
			input:    []byte(`"value with \\ backslash"`),
			expected: []byte(`value with \ backslash`),
		},
		{
			name:     "quoted string with multiple escapes",
			input:    []byte(`"a\"b\\c\"d"`),
			expected: []byte(`a"b\c"d`),
		},
		{
			// A lone quote cannot be an opening+closing pair; returned as-is.
			name:     "too short input",
			input:    []byte(`"`),
			expected: []byte(`"`),
		},
		{
			name:     "empty input",
			input:    []byte(``),
			expected: []byte(``),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			result := unquoteCacheDirective(tt.input)
			require.Equal(t, tt.expected, result)
		})
	}
}
// Test_Cache_MaxBytes_InsufficientSpace covers responses whose body is larger
// than the configured MaxBytes budget. Such responses must be served with
// X-Cache set to cacheUnreachable, whether the cache is empty or already holds
// a smaller entry.
func Test_Cache_MaxBytes_InsufficientSpace(t *testing.T) {
	t.Parallel()

	t.Run("entry larger than MaxBytes with empty cache", func(t *testing.T) {
		t.Parallel()

		app := fiber.New()
		app.Use(New(Config{
			MaxBytes:   10, // smaller than the 20-byte body below
			Expiration: 1 * time.Hour,
		}))
		app.Get("/large", func(c fiber.Ctx) error {
			oversized := make([]byte, 20)
			return c.Send(oversized)
		})

		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/large", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})

	t.Run("entry larger than MaxBytes after eviction", func(t *testing.T) {
		t.Parallel()

		app := fiber.New()
		app.Use(New(Config{
			MaxBytes:            15,
			ExpirationGenerator: stableAscendingExpiration(),
		}))
		app.Get("/*", func(c fiber.Ctx) error {
			size := 20
			if c.Path() == "/small" {
				size = 5
			}
			return c.Send(make([]byte, size))
		})

		// get requests path and returns the X-Cache header of the response.
		get := func(path string) string {
			t.Helper()
			rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
			require.NoError(t, err)
			return rsp.Header.Get("X-Cache")
		}

		// A 5-byte body fits within the 15-byte budget.
		require.Equal(t, cacheMiss, get("/small"))
		// A 20-byte body exceeds the whole 15-byte budget, even with nothing else stored.
		require.Equal(t, cacheUnreachable, get("/large"))
	})
}
// Test_Cache_MaxBytes_DeletionFailureRestoresTracking checks what happens when
// eviction fails because storage cannot delete the victim entry:
//   - the triggering request must surface a 500 "failed to delete key" error;
//   - once storage recovers, new entries must still be accepted, showing that
//     size tracking was restored.
func Test_Cache_MaxBytes_DeletionFailureRestoresTracking(t *testing.T) {
	t.Parallel()
	storage := newFailingCacheStorage()
	app := fiber.New()
	app.Use(New(Config{
		MaxBytes:   4, // room for a single 4-byte "data" body
		Expiration: 1 * time.Hour,
		Storage:    storage,
	}))
	app.Get("/:name", func(c fiber.Ctx) error {
		return c.SendString("data")
	})

	// Seed the cache with a single entry
	rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/first", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
	require.NoError(t, rsp.Body.Close())

	// Make every stored key belonging to /first fail on delete. Guard that at
	// least one key matched. Otherwise eviction would succeed silently and the
	// 500 assertion below would fail with a misleading message.
	var storedKeys []string
	injected := false
	storage.mu.Lock()
	for key := range storage.data {
		storedKeys = append(storedKeys, key)
		if strings.Contains(key, "/first") {
			storage.errs["del|"+key] = errors.New("delete failed")
			injected = true
		}
	}
	storage.mu.Unlock()
	t.Logf("stored keys after first cache: %v", storedKeys)
	require.True(t, injected, "no stored key contains /first; delete failure was not injected")

	// Next request triggers eviction; deletion failure should surface an error
	rsp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/second", http.NoBody))
	require.NoError(t, err)
	body, err := io.ReadAll(rsp.Body)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, rsp.StatusCode)
	require.Contains(t, string(body), "failed to delete key")
	require.NoError(t, rsp.Body.Close())

	var remainingKeys []string
	storage.mu.RLock()
	for key := range storage.data {
		remainingKeys = append(remainingKeys, key)
	}
	storage.mu.RUnlock()
	t.Logf("stored keys after deletion failure: %v", remainingKeys)

	// Clear all injected failures so storage behaves normally again.
	storage.mu.Lock()
	storage.errs = make(map[string]error)
	storage.mu.Unlock()

	// Another request should succeed and be cacheable after restoring heap tracking
	rsp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/third", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
	require.NoError(t, rsp.Body.Close())
}
// Test_Cache_MaxBytes_ConcurrencyAndRaceConditions drives many concurrent
// requests through a MaxBytes-limited cache. The aim is to let the race
// detector (go test -race) flag unsynchronized access to size accounting and
// eviction state. Both subtests assert only that every request completes
// without error; they do not inspect the stored byte count directly.
func Test_Cache_MaxBytes_ConcurrencyAndRaceConditions(t *testing.T) {
	t.Parallel()

	// fetch issues a GET for path and closes the response body, returning the
	// first error encountered, if any.
	fetch := func(app *fiber.App, path string) error {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		if err != nil {
			return err
		}
		return resp.Body.Close()
	}

	t.Run("concurrent requests with MaxBytes limit", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		const maxBytes = uint(1000)
		const numGoroutines = 20
		const requestsPerGoroutine = 5
		app.Use(New(Config{
			MaxBytes:   maxBytes,
			Expiration: 10 * time.Second,
		}))
		app.Get("/*", func(c fiber.Ctx) error {
			// 100 distinct paths x 50 bytes together exceed maxBytes.
			return c.Send(make([]byte, 50))
		})

		var wg sync.WaitGroup
		// Each request reports at most one error, so sends never block.
		errChan := make(chan error, numGoroutines*requestsPerGoroutine)
		for id := range numGoroutines {
			wg.Add(1) //nolint:revive // Standard WaitGroup pattern is appropriate here
			go func() {
				defer wg.Done()
				for j := range requestsPerGoroutine {
					if err := fetch(app, fmt.Sprintf("/test-%d-%d", id, j)); err != nil {
						errChan <- err
					}
				}
			}()
		}
		wg.Wait()
		close(errChan)

		for err := range errChan {
			require.NoError(t, err, "concurrent request failed")
		}
	})

	t.Run("concurrent requests near capacity triggers eviction", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		const maxBytes = uint(200)
		const numRequests = 10
		app.Use(New(Config{
			MaxBytes:   maxBytes,
			Expiration: 10 * time.Second,
		}))
		app.Get("/*", func(c fiber.Ctx) error {
			// 10 requests x 50 bytes exceed maxBytes, forcing evictions.
			return c.Send(make([]byte, 50))
		})

		var wg sync.WaitGroup
		errChan := make(chan error, numRequests)
		for id := range numRequests {
			wg.Add(1) //nolint:revive // Standard WaitGroup pattern is appropriate here
			go func() {
				defer wg.Done()
				if err := fetch(app, fmt.Sprintf("/item-%d", id)); err != nil {
					errChan <- err
				}
			}()
		}
		wg.Wait()
		close(errChan)

		// Fail on request errors, as the first subtest does, instead of only logging them.
		for err := range errChan {
			require.NoError(t, err, "concurrent request failed")
		}
	})
}
// Test_Cache_HelperFunctions pins the behavior of the cache package's small
// internal helpers: HTTP date parsing, time/duration conversion, freshness
// arithmetic, cached-header lookup, heuristic-freshness detection and
// body-fetch error mapping.
func Test_Cache_HelperFunctions(t *testing.T) {
	t.Parallel()
	// parseHTTPDate reports (0, false) for empty and unparsable input and a
	// positive value with ok=true for a valid HTTP date.
	t.Run("parseHTTPDate empty", func(t *testing.T) {
		t.Parallel()
		result, ok := parseHTTPDate([]byte{})
		require.False(t, ok)
		require.Equal(t, uint64(0), result)
	})
	t.Run("parseHTTPDate invalid", func(t *testing.T) {
		t.Parallel()
		result, ok := parseHTTPDate([]byte("invalid"))
		require.False(t, ok)
		require.Equal(t, uint64(0), result)
	})
	t.Run("parseHTTPDate valid", func(t *testing.T) {
		t.Parallel()
		result, ok := parseHTTPDate([]byte("Mon, 02 Jan 2006 15:04:05 GMT"))
		require.True(t, ok)
		require.Positive(t, result)
	})
	// safeUnixSeconds clamps pre-epoch times to 0 rather than wrapping the uint64.
	t.Run("safeUnixSeconds negative", func(t *testing.T) {
		t.Parallel()
		result := safeUnixSeconds(time.Unix(-1, 0))
		require.Equal(t, uint64(0), result)
	})
	t.Run("safeUnixSeconds positive", func(t *testing.T) {
		t.Parallel()
		result := safeUnixSeconds(time.Unix(1234567890, 0))
		require.Equal(t, uint64(1234567890), result)
	})
	// remainingFreshness(e, now) is exp-now while exp > now. It is 0 for a nil
	// item, a zero expiry, or an already-expired entry.
	t.Run("remainingFreshness nil", func(t *testing.T) {
		t.Parallel()
		result := remainingFreshness(nil, 100)
		require.Equal(t, uint64(0), result)
	})
	t.Run("remainingFreshness zero exp", func(t *testing.T) {
		t.Parallel()
		e := &item{exp: 0}
		result := remainingFreshness(e, 100)
		require.Equal(t, uint64(0), result)
	})
	t.Run("remainingFreshness expired", func(t *testing.T) {
		t.Parallel()
		e := &item{exp: 100}
		result := remainingFreshness(e, 200)
		require.Equal(t, uint64(0), result)
	})
	t.Run("remainingFreshness valid", func(t *testing.T) {
		t.Parallel()
		e := &item{exp: 200}
		result := remainingFreshness(e, 100)
		require.Equal(t, uint64(100), result)
	})
	// lookupCachedHeader matches header names case-insensitively. A missing
	// name yields (nil, false).
	t.Run("lookupCachedHeader not found", func(t *testing.T) {
		t.Parallel()
		headers := []cachedHeader{{key: []byte("Content-Type"), value: []byte("text/html")}}
		value, found := lookupCachedHeader(headers, "Authorization")
		require.False(t, found)
		require.Nil(t, value)
	})
	t.Run("lookupCachedHeader case insensitive", func(t *testing.T) {
		t.Parallel()
		headers := []cachedHeader{{key: []byte("Authorization"), value: []byte("Bearer token")}}
		value, found := lookupCachedHeader(headers, "authorization")
		require.True(t, found)
		require.Equal(t, []byte("Bearer token"), value)
	})
	t.Run("secondsToDuration zero", func(t *testing.T) {
		t.Parallel()
		result := secondsToDuration(0)
		require.Equal(t, time.Duration(0), result)
	})
	// NOTE(review): 9223372036s is ~9.223372036e18ns, just below math.MaxInt64
	// nanoseconds. This exercises a near-limit value, not an overflowing one.
	t.Run("secondsToDuration large", func(t *testing.T) {
		t.Parallel()
		result := secondsToDuration(9223372036)
		require.Greater(t, result, time.Duration(0))
	})
	t.Run("secondsToTime zero", func(t *testing.T) {
		t.Parallel()
		result := secondsToTime(0)
		require.Equal(t, time.Unix(0, 0).UTC(), result)
	})
	t.Run("secondsToTime value", func(t *testing.T) {
		t.Parallel()
		result := secondsToTime(1234567890)
		require.Equal(t, time.Unix(1234567890, 0).UTC(), result)
	})
	// isHeuristicFreshness cases:
	//   - a 1h age is not heuristic;
	//   - a 25h age is heuristic with no Expires header;
	//   - a 25h age is not heuristic when Expires is present.
	// NOTE(review): the exact age threshold (somewhere between 1h and 25h)
	// lives in isHeuristicFreshness; confirm there before adding boundary cases.
	t.Run("isHeuristicFreshness short age", func(t *testing.T) {
		t.Parallel()
		cfg := &Config{Expiration: 1 * time.Hour}
		e := &item{cacheControl: []byte("public")}
		result := isHeuristicFreshness(e, cfg, 3600)
		require.False(t, result)
	})
	t.Run("isHeuristicFreshness with expires", func(t *testing.T) {
		t.Parallel()
		cfg := &Config{Expiration: 1 * time.Hour}
		e := &item{cacheControl: []byte("public"), expires: []byte("Wed, 21 Oct 2015 07:28:00 GMT")}
		result := isHeuristicFreshness(e, cfg, uint64(25*time.Hour/time.Second))
		require.False(t, result)
	})
	t.Run("isHeuristicFreshness true", func(t *testing.T) {
		t.Parallel()
		cfg := &Config{Expiration: 1 * time.Hour}
		e := &item{cacheControl: []byte("public")}
		result := isHeuristicFreshness(e, cfg, uint64(25*time.Hour/time.Second))
		require.True(t, result)
	})
	// cacheBodyFetchError wraps errCacheMiss into a "no cached body" error and
	// passes any other error through unchanged.
	t.Run("cacheBodyFetchError miss", func(t *testing.T) {
		t.Parallel()
		mask := func(_ string) string { return "***" }
		err := cacheBodyFetchError(mask, "key", errCacheMiss)
		require.Error(t, err)
		require.Contains(t, err.Error(), "no cached body")
	})
	t.Run("cacheBodyFetchError other", func(t *testing.T) {
		t.Parallel()
		mask := func(_ string) string { return "***" }
		originalErr := errors.New("storage error")
		err := cacheBodyFetchError(mask, "key", originalErr)
		require.Equal(t, originalErr, err)
	})
}
// Test_Cache_VaryAndAuth covers Vary-manifest persistence and caching of
// Vary/Authorization responses:
//   - a storage write failure must propagate from storeVaryManifest;
//   - a missing manifest must load as (nil, false, nil);
//   - repeated requests with identical Vary'd headers must hit the cache;
//   - an authenticated response marked must-revalidate must be cacheable.
func Test_Cache_VaryAndAuth(t *testing.T) {
	t.Parallel()
	t.Run("storeVaryManifest failure", func(t *testing.T) {
		t.Parallel()
		storage := newFailingCacheStorage()
		storage.errs["set|manifest"] = errors.New("storage fail")
		// Named mgr so the local does not shadow the package-level manager type.
		mgr := &manager{storage: storage}
		err := storeVaryManifest(context.Background(), mgr, "manifest", []string{"Accept"}, 3600*time.Second)
		require.Error(t, err)
	})
	t.Run("loadVaryManifest not found", func(t *testing.T) {
		t.Parallel()
		storage := newFailingCacheStorage()
		mgr := &manager{storage: storage}
		varyNames, found, err := loadVaryManifest(context.Background(), mgr, "nonexistent")
		require.NoError(t, err)
		require.False(t, found)
		require.Nil(t, varyNames)
	})
	t.Run("vary with multiple headers", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Vary", "Accept, Accept-Encoding")
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			return c.SendString("test")
		})
		req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req.Header.Set("Accept", "application/json")
		req.Header.Set("Accept-Encoding", "gzip")
		rsp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		// Same values for every Vary'd header, so the stored variant applies.
		req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req2.Header.Set("Accept", "application/json")
		req2.Header.Set("Accept-Encoding", "gzip")
		rsp2, err := app.Test(req2)
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})
	t.Run("auth with must-revalidate", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "must-revalidate, max-age=3600")
			return c.SendString("content")
		})
		req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req.Header.Set("Authorization", "Bearer token1")
		rsp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req2.Header.Set("Authorization", "Bearer token1")
		rsp2, err := app.Test(req2)
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})
}
// Test_Cache_DateAndCacheControl checks that responses with assorted Date and
// Cache-Control header shapes are reported as a cache miss, not unreachable,
// on the first request. The shapes covered are:
//   - a valid Date header;
//   - an invalid Date header;
//   - quoted directive values containing commas;
//   - extra whitespace around directives.
func Test_Cache_DateAndCacheControl(t *testing.T) {
	t.Parallel()

	type header struct{ key, value string }
	cases := []struct {
		name    string
		headers []header // applied to the response in order
	}{
		{
			name: "date header parsing",
			headers: []header{
				{"Date", "Mon, 02 Jan 2006 15:04:05 GMT"},
				{"Cache-Control", "max-age=3600"},
			},
		},
		{
			name: "invalid date header",
			headers: []header{
				{"Date", "invalid"},
				{"Cache-Control", "max-age=3600"},
			},
		},
		{
			name:    "cache control with quoted values",
			headers: []header{{"Cache-Control", `max-age=3600, ext="value, with, commas"`}},
		},
		{
			name:    "cache control with spaces",
			headers: []header{{"Cache-Control", "max-age=3600 , public , must-revalidate"}},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			app.Use(New(Config{Expiration: 1 * time.Hour}))
			app.Get("/test", func(c fiber.Ctx) error {
				for _, hdr := range tc.headers {
					c.Response().Header.Set(hdr.key, hdr.value)
				}
				return c.SendString("test")
			})
			rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
			require.NoError(t, err)
			require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		})
	}
}
// Test_Cache_CacheControlCombinations tests common cache control directive combinations
func Test_Cache_CacheControlCombinations(t *testing.T) {
t.Parallel()
t.Run("max-age with public", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "public, max-age=3600")
return c.SendString("public content")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
})
t.Run("max-age with private", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "private, max-age=3600")
return c.SendString("private content")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
})
t.Run("s-maxage overrides max-age", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "public, max-age=60, s-maxage=3600")
return c.SendString("content")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
})
t.Run("no-store prevents caching", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "no-store")
return c.SendString("no store content")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheUnreachable, rsp2.Header.Get("X-Cache"))
})
t.Run("no-cache with etag", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "no-cache")
c.Response().Header.Set("ETag", `"123456"`)
return c.SendString("no-cache content")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
})
t.Run("must-revalidate with max-age", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "must-revalidate, max-age=3600")
return c.SendString("must revalidate content")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
})
t.Run("proxy-revalidate with max-age", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "public, proxy-revalidate, max-age=3600")
return c.SendString("proxy revalidate content")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
})
t.Run("immutable with max-age", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "public, max-age=31536000, immutable")
return c.SendString("immutable content")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
})
t.Run("max-age=0 with must-revalidate", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "max-age=0, must-revalidate")
return c.SendString("always revalidate")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
})
t.Run("public with no explicit max-age", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "public")
return c.SendString("public no max-age")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
})
t.Run("multiple cache directives with extensions", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", `public, max-age=3600, custom="value"`)
return c.SendString("content")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
})
t.Run("private overrides public", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "public, private, max-age=3600")
return c.SendString("conflicting directives")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
})
t.Run("stale-while-revalidate with max-age", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "max-age=60, stale-while-revalidate=120")
return c.SendString("stale while revalidate")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
})
t.Run("stale-if-error with max-age", func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{Expiration: 1 * time.Hour}))
app.Get("/test", func(c fiber.Ctx) error {
c.Response().Header.Set("Cache-Control", "max-age=60, stale-if-error=3600")
return c.SendString("stale if error")
})
rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
require.NoError(t, err)
require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
})
}
// Test_Cache_RequestResponseDirectives tests caching behavior with various request/response cache-control directives.
func Test_Cache_RequestResponseDirectives(t *testing.T) {
	t.Parallel()

	// respond returns a handler that sets the given response header
	// name/value pairs (in order) and then sends body.
	respond := func(body string, headers ...string) func(c fiber.Ctx) error {
		return func(c fiber.Ctx) error {
			for i := 0; i+1 < len(headers); i += 2 {
				c.Response().Header.Set(headers[i], headers[i+1])
			}
			return c.SendString(body)
		}
	}

	// serve mounts handler on GET /test behind the cache middleware built from
	// cfg. The returned function performs GET /test with optional request
	// header name/value pairs and fails the test if the request errors.
	serve := func(cfg Config, handler func(c fiber.Ctx) error) func(t *testing.T, headers ...string) *http.Response {
		app := fiber.New()
		app.Use(New(cfg))
		app.Get("/test", handler)
		return func(t *testing.T, headers ...string) *http.Response {
			t.Helper()
			req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
			for i := 0; i+1 < len(headers); i += 2 {
				req.Header.Set(headers[i], headers[i+1])
			}
			rsp, err := app.Test(req)
			require.NoError(t, err)
			return rsp
		}
	}

	t.Run("negative expiration skips caching", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: -1 * time.Second}, respond("test"))
		status := get(t).Header.Get("X-Cache")
		require.NotEqual(t, cacheMiss, status)
		require.NotEqual(t, cacheHit, status)
	})
	t.Run("request with no-store directive", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Hour}, respond("test"))
		rsp := get(t, "Cache-Control", "no-store")
		require.NotEqual(t, cacheMiss, rsp.Header.Get("X-Cache"))
	})
	t.Run("request with pragma no-cache", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Hour}, respond("test", "Cache-Control", "max-age=3600"))
		rsp := get(t, "Pragma", "no-cache")
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
	})
	t.Run("method not in allowed methods list", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			Expiration: 1 * time.Hour,
			Methods:    []string{fiber.MethodGet},
		}))
		app.Post("/test", respond("test"))
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})
	t.Run("request with min-fresh directive", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Hour}, respond("test", "Cache-Control", "max-age=60"))
		// Prime the cache.
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		// A min-fresh beyond the remaining freshness must not be answered from cache.
		status := get(t, "Cache-Control", "min-fresh=120").Header.Get("X-Cache")
		require.Contains(t, []string{cacheMiss, cacheUnreachable}, status, "min-fresh requirement should prevent cache hit")
	})
	t.Run("request with max-age=0 directive", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Hour}, respond("test", "Cache-Control", "max-age=3600"))
		// Prime the cache, then ask for a response no older than zero seconds.
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		require.Equal(t, cacheMiss, get(t, "Cache-Control", "max-age=0").Header.Get("X-Cache"))
	})
	t.Run("request with max-stale directive", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Second}, respond("test", "Cache-Control", "max-age=1"))
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		// Let the entry go stale, then explicitly accept stale content.
		time.Sleep(2 * time.Second)
		status := get(t, "Cache-Control", "max-stale=60").Header.Get("X-Cache")
		// Either the stale entry is served or the origin is consulted again.
		require.Contains(t, []string{cacheHit, cacheMiss, "stale"}, status, "max-stale should allow stale content or revalidate")
	})
	t.Run("response with expires header", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Hour}, func(c fiber.Ctx) error {
			c.Response().Header.Set("Expires", time.Now().Add(1*time.Hour).Format(time.RFC1123))
			return c.SendString("test")
		})
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		require.Equal(t, cacheHit, get(t).Header.Get("X-Cache"))
	})
	t.Run("response with age header", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Hour}, respond("test", "Cache-Control", "max-age=3600", "Age", "30"))
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		require.Equal(t, cacheHit, get(t).Header.Get("X-Cache"))
	})
	t.Run("custom key generator", func(t *testing.T) {
		t.Parallel()
		cfg := Config{
			Expiration: 1 * time.Hour,
			KeyGenerator: func(c fiber.Ctx) string {
				return "custom-" + c.Path()
			},
		}
		get := serve(cfg, respond("test"))
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		require.Equal(t, cacheHit, get(t).Header.Get("X-Cache"))
	})
	t.Run("response with warning header", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Second}, respond("test", "Cache-Control", "max-age=1"))
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		// Let the entry go stale before asking again.
		time.Sleep(2 * time.Second)
		rsp := get(t)
		revalidated := rsp.Header.Get("X-Cache") == cacheMiss
		warned := rsp.Header.Get("Warning") != ""
		require.True(t, revalidated || warned, "stale response should either revalidate or have warning header")
	})
	t.Run("external storage with body key", func(t *testing.T) {
		t.Parallel()
		storage := newFailingCacheStorage()
		get := serve(Config{Expiration: 1 * time.Hour, Storage: storage}, respond("test content"))
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		// The body should have been written under its own "_body" key.
		var hasBodyKey bool
		storage.mu.RLock()
		for key := range storage.data {
			if hasBodyKey = strings.Contains(key, "_body"); hasBodyKey {
				break
			}
		}
		storage.mu.RUnlock()
		require.True(t, hasBodyKey)
	})
	t.Run("only-if-cached with cache miss", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Hour}, respond("test"))
		rsp := get(t, "Cache-Control", "only-if-cached")
		require.Equal(t, fiber.StatusGatewayTimeout, rsp.StatusCode)
	})
	t.Run("only-if-cached with cache hit", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Hour}, respond("test", "Cache-Control", "max-age=3600"))
		// Prime the cache, then demand a cached-only answer.
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		require.Equal(t, cacheHit, get(t, "Cache-Control", "only-if-cached").Header.Get("X-Cache"))
	})
	t.Run("cache control with uppercase directives", func(t *testing.T) {
		t.Parallel()
		get := serve(Config{Expiration: 1 * time.Hour}, respond("test", "Cache-Control", "PUBLIC, MAX-AGE=3600"))
		require.Equal(t, cacheMiss, get(t).Header.Get("X-Cache"))
		require.Equal(t, cacheHit, get(t).Header.Get("X-Cache"))
	})
}
// Test_Cache_ConfigurationAndResponseHandling tests cache behavior for specific configuration and response edge cases.
func Test_Cache_ConfigurationAndResponseHandling(t *testing.T) {
	t.Parallel()

	// respond returns a handler that sets the given response header
	// name/value pairs (in order) and then sends body.
	respond := func(body string, headers ...string) func(c fiber.Ctx) error {
		return func(c fiber.Ctx) error {
			for i := 0; i+1 < len(headers); i += 2 {
				c.Response().Header.Set(headers[i], headers[i+1])
			}
			return c.SendString(body)
		}
	}

	// serve mounts handler on GET route behind the cache middleware built from
	// cfg. The returned function performs GET target with optional request
	// header name/value pairs and fails the test if the request errors.
	serve := func(route string, cfg Config, handler func(c fiber.Ctx) error) func(t *testing.T, target string, headers ...string) *http.Response {
		app := fiber.New()
		app.Use(New(cfg))
		app.Get(route, handler)
		return func(t *testing.T, target string, headers ...string) *http.Response {
			t.Helper()
			req := httptest.NewRequest(fiber.MethodGet, target, http.NoBody)
			for i := 0; i+1 < len(headers); i += 2 {
				req.Header.Set(headers[i], headers[i+1])
			}
			rsp, err := app.Test(req)
			require.NoError(t, err)
			return rsp
		}
	}

	t.Run("response with Vary star prevents caching", func(t *testing.T) {
		t.Parallel()
		get := serve("/test", Config{Expiration: 1 * time.Hour},
			respond("test", "Vary", "*", "Cache-Control", "max-age=3600"))
		require.Equal(t, cacheUnreachable, get(t, "/test").Header.Get("X-Cache"))
	})
	t.Run("next function prevents caching", func(t *testing.T) {
		t.Parallel()
		cfg := Config{
			Expiration: 1 * time.Hour,
			Next: func(c fiber.Ctx) bool {
				return c.Path() == "/skip"
			},
		}
		get := serve("/skip", cfg, respond("test"))
		require.Equal(t, cacheUnreachable, get(t, "/skip").Header.Get("X-Cache"))
	})
	t.Run("non-cacheable status code", func(t *testing.T) {
		t.Parallel()
		get := serve("/test", Config{Expiration: 1 * time.Hour}, func(c fiber.Ctx) error {
			return c.Status(fiber.StatusCreated).SendString("created")
		})
		require.Equal(t, cacheUnreachable, get(t, "/test").Header.Get("X-Cache"))
	})
	t.Run("body larger than MaxBytes", func(t *testing.T) {
		t.Parallel()
		cfg := Config{
			Expiration: 1 * time.Hour,
			MaxBytes:   10,
		}
		get := serve("/test", cfg, func(c fiber.Ctx) error {
			return c.Send(make([]byte, 100))
		})
		require.Equal(t, cacheUnreachable, get(t, "/test").Header.Get("X-Cache"))
	})
	t.Run("authorization without shared cache directives", func(t *testing.T) {
		t.Parallel()
		get := serve("/test", Config{Expiration: 1 * time.Hour}, respond("test", "Cache-Control", "max-age=3600"))
		rsp := get(t, "/test", "Authorization", "******")
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})
	t.Run("disable cache control header generation", func(t *testing.T) {
		t.Parallel()
		cfg := Config{
			Expiration:          1 * time.Hour,
			DisableCacheControl: true,
		}
		get := serve("/test", cfg, respond("test", "Cache-Control", "max-age=3600"))
		require.Equal(t, cacheMiss, get(t, "/test").Header.Get("X-Cache"))
		require.Equal(t, cacheHit, get(t, "/test").Header.Get("X-Cache"))
	})
	t.Run("disable value redaction", func(t *testing.T) {
		t.Parallel()
		cfg := Config{
			Expiration:            1 * time.Hour,
			DisableValueRedaction: true,
		}
		get := serve("/test", cfg, respond("test", "Cache-Control", "max-age=3600"))
		require.Equal(t, cacheMiss, get(t, "/test").Header.Get("X-Cache"))
	})
	t.Run("response with ETag header", func(t *testing.T) {
		t.Parallel()
		get := serve("/test", Config{Expiration: 1 * time.Hour},
			respond("test", "Cache-Control", "max-age=3600", "ETag", `"abc123"`))
		require.Equal(t, cacheMiss, get(t, "/test").Header.Get("X-Cache"))
		cached := get(t, "/test")
		require.Equal(t, cacheHit, cached.Header.Get("X-Cache"))
		require.Equal(t, `"abc123"`, cached.Header.Get("ETag"))
	})
	t.Run("response with Content-Encoding header", func(t *testing.T) {
		t.Parallel()
		get := serve("/test", Config{Expiration: 1 * time.Hour},
			respond("test", "Cache-Control", "max-age=3600", "Content-Encoding", "gzip"))
		require.Equal(t, cacheMiss, get(t, "/test").Header.Get("X-Cache"))
		cached := get(t, "/test")
		require.Equal(t, cacheHit, cached.Header.Get("X-Cache"))
		require.Equal(t, "gzip", cached.Header.Get("Content-Encoding"))
	})
	t.Run("response with custom headers preserved", func(t *testing.T) {
		t.Parallel()
		cfg := Config{
			Expiration:           1 * time.Hour,
			StoreResponseHeaders: true,
		}
		get := serve("/test", cfg,
			respond("test", "Cache-Control", "max-age=3600", "X-Custom-Header", "custom-value"))
		require.Equal(t, cacheMiss, get(t, "/test").Header.Get("X-Cache"))
		cached := get(t, "/test")
		require.Equal(t, cacheHit, cached.Header.Get("X-Cache"))
		require.Equal(t, "custom-value", cached.Header.Get("X-Custom-Header"))
	})
	t.Run("revalidation scenario with cache miss", func(t *testing.T) {
		t.Parallel()
		get := serve("/test", Config{Expiration: 1 * time.Hour}, respond("test", "Cache-Control", "max-age=3600"))
		// A request-side no-cache forces revalidation with the origin.
		rsp := get(t, "/test", "Cache-Control", "no-cache")
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
	})
	t.Run("delete vary manifest on no-cache response", func(t *testing.T) {
		t.Parallel()
		get := serve("/test", Config{Expiration: 1 * time.Hour}, func(c fiber.Ctx) error {
			if c.Query("first") == "" {
				// Follow-up requests answer no-cache, which should drop the manifest.
				c.Response().Header.Set("Cache-Control", "no-cache")
				return c.SendString("test")
			}
			// The first request creates a Vary manifest.
			c.Response().Header.Set("Vary", "Accept")
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			return c.SendString("test")
		})
		first := get(t, "/test?first=true", "Accept", "application/json")
		require.Equal(t, cacheMiss, first.Header.Get("X-Cache"))
		// The second response carries no Vary header and no-cache.
		require.Equal(t, cacheUnreachable, get(t, "/test").Header.Get("X-Cache"))
	})
	t.Run("vary manifest deletion on different vary response", func(t *testing.T) {
		t.Parallel()
		var counter atomic.Int32
		get := serve("/test", Config{Expiration: 1 * time.Hour}, func(c fiber.Ctx) error {
			// Only the first response varies on Accept; later ones drop Vary,
			// which should delete the manifest.
			if counter.Add(1) == 1 {
				c.Response().Header.Set("Vary", "Accept")
			}
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			return c.SendString("test")
		})
		first := get(t, "/test", "Accept", "application/json")
		require.Equal(t, cacheMiss, first.Header.Get("X-Cache"))
		require.Equal(t, cacheMiss, get(t, "/test").Header.Get("X-Cache"))
	})
}
================================================
FILE: middleware/cache/config.go
================================================
package cache
import (
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
)
// Config defines the config for middleware.
type Config struct {
	// Storage is used to store the state of the middleware.
	//
	// Default: an in-memory store for this process only
	Storage fiber.Storage
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool
	// CacheInvalidator defines a function to invalidate the cache when returned true.
	//
	// Optional. Default: nil
	CacheInvalidator func(fiber.Ctx) bool
	// KeyGenerator allows you to generate custom cache keys; by default c.Path() is used.
	//
	// Default: func(c fiber.Ctx) string {
	//   return utils.CopyString(c.Path())
	// }
	KeyGenerator func(fiber.Ctx) string
	// ExpirationGenerator allows you to generate a custom expiration per request.
	// If nil, the Expiration value is used.
	//
	// Default: nil
	ExpirationGenerator func(fiber.Ctx, *Config) time.Duration
	// CacheHeader is the name of the response header that reports the cache
	// status. Its value is one of:
	//
	//   hit, miss, unreachable
	//
	// Optional. Default: X-Cache
	CacheHeader string
	// Methods lists the HTTP methods whose responses the middleware caches;
	// routes handling other methods are not cached.
	//
	// Default: []string{fiber.MethodGet, fiber.MethodHead}
	Methods []string
	// Expiration is the time that a cached response will live.
	// Note: configDefault treats any value strictly between -1s and 1s
	// (including 0) as unset and replaces it with the default.
	//
	// Optional. Default: 5 * time.Minute
	Expiration time.Duration
	// MaxBytes is the maximum number of bytes of response bodies stored in the
	// cache at the same time. When the limit is reached, entries with the
	// nearest expiration are deleted to make room for new ones.
	// 0 means no limit.
	//
	// NOTE(review): configDefault does not fill this field in, so a Config
	// passed with MaxBytes left at 0 keeps 0 (no limit); the 1 MiB default
	// below only comes from ConfigDefault itself — confirm this is intended.
	//
	// Optional. Default: 1 * 1024 * 1024
	MaxBytes uint
	// DisableValueRedaction turns off masking cache keys in logs and error messages when set to true.
	//
	// Optional. Default: false
	DisableValueRedaction bool
	// DisableCacheControl disables client side caching if set to true.
	//
	// Optional. Default: false
	DisableCacheControl bool
	// StoreResponseHeaders allows you to store additional headers generated by
	// next middlewares and handlers.
	//
	// Default: false
	StoreResponseHeaders bool
}
// ConfigDefault is the default config.
//
// It is returned unchanged by configDefault when no Config is supplied, and
// its fields serve as the fallbacks for the fields configDefault fills in.
var ConfigDefault = Config{
	// Copy the path so the key does not alias request-owned memory.
	KeyGenerator: func(c fiber.Ctx) string {
		return utils.CopyString(c.Path())
	},
	Methods:               []string{fiber.MethodGet, fiber.MethodHead},
	CacheHeader:           "X-Cache",
	Expiration:            5 * time.Minute,
	MaxBytes:              1 * 1024 * 1024,
	Storage:               nil,
	Next:                  nil,
	CacheInvalidator:      nil,
	ExpirationGenerator:   nil,
	DisableValueRedaction: false,
	DisableCacheControl:   false,
	StoreResponseHeaders:  false,
}
// configDefault returns the effective middleware configuration.
//
// With no argument it returns ConfigDefault as-is. Otherwise only config[0]
// is used, and these fields fall back to ConfigDefault when unset:
// Next (nil), Expiration (strictly between -1s and 1s), CacheHeader (""),
// KeyGenerator (nil) and Methods (empty). Every other field, including
// MaxBytes and Storage, is returned exactly as supplied.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}
	cfg := config[0]

	if cfg.Next == nil {
		cfg.Next = ConfigDefault.Next
	}
	// Same condition as int(cfg.Expiration.Seconds()) == 0: the whole-second
	// part is zero, so sub-second values (positive or negative) are replaced.
	if cfg.Expiration > -time.Second && cfg.Expiration < time.Second {
		cfg.Expiration = ConfigDefault.Expiration
	}
	if cfg.CacheHeader == "" {
		cfg.CacheHeader = ConfigDefault.CacheHeader
	}
	if cfg.KeyGenerator == nil {
		cfg.KeyGenerator = ConfigDefault.KeyGenerator
	}
	if len(cfg.Methods) == 0 {
		cfg.Methods = ConfigDefault.Methods
	}
	return cfg
}
================================================
FILE: middleware/cache/heap.go
================================================
package cache
import (
"container/heap"
)
type heapEntry struct {
key string
exp uint64
bytes uint
idx int
}
// indexedHeap is a regular min-heap that allows finding
// elements in constant time. It does so by handing out special indices
// and tracking entry movement.
//
// indexedHeap is used for quickly finding entries with the lowest
// expiration timestamp and deleting arbitrary entries.
type indexedHeap struct {
// Slice the heap is built on
entries []heapEntry
// Mapping "index" to position in heap slice
indices []int
// Max index handed out
maxidx int
}
// Len implements heap.Interface by reporting the number of entries in the heap.
func (h indexedHeap) Len() int {
return len(h.entries)
}
// Less implements heap.Interface and orders entries by expiration time.
func (h indexedHeap) Less(i, j int) bool {
return h.entries[i].exp < h.entries[j].exp
}
// Swap implements heap.Interface and swaps the entries at the provided indices.
func (h indexedHeap) Swap(i, j int) {
h.entries[i], h.entries[j] = h.entries[j], h.entries[i]
h.indices[h.entries[i].idx] = i
h.indices[h.entries[j].idx] = j
}
// Push implements heap.Interface and inserts a new entry into the heap.
func (h *indexedHeap) Push(x any) {
h.pushInternal(x.(heapEntry)) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface
}
// Pop implements heap.Interface and removes the last entry from the heap.
func (h *indexedHeap) Pop() any {
n := len(h.entries)
h.entries = h.entries[0 : n-1]
return h.entries[0:n][n-1]
}
func (h *indexedHeap) pushInternal(entry heapEntry) {
h.indices[entry.idx] = len(h.entries)
h.entries = append(h.entries, entry)
}
// Returns index to track entry
func (h *indexedHeap) put(key string, exp uint64, bytes uint) int {
idx := 0
if len(h.entries) < h.maxidx {
// Steal index from previously removed entry
// capacity > size is guaranteed
n := len(h.entries)
idx = h.entries[:n+1][n].idx
} else {
idx = h.maxidx
h.maxidx++
h.indices = append(h.indices, idx)
}
// Push manually to avoid allocation
h.pushInternal(heapEntry{
key: key, exp: exp, idx: idx, bytes: bytes,
})
heap.Fix(h, h.Len()-1)
return idx
}
// removeInternal deletes the entry at heap position realIdx via heap.Remove
// (which restores heap order) and returns the removed entry's key and
// recorded size.
func (h *indexedHeap) removeInternal(realIdx int) (key string, size uint) { //nolint:nonamedreturns // gocritic unnamedResult prefers named key and size when removing heap entries
	removed := heap.Remove(h, realIdx).(heapEntry) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface
	key, size = removed.key, removed.bytes
	return key, size
}
// remove deletes the entry identified by idx (a value previously returned by
// put) and returns its key and recorded size.
func (h *indexedHeap) remove(idx int) (key string, size uint) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned key and size pair
	position := h.indices[idx]
	key, size = h.removeInternal(position)
	return key, size
}
// removeFirst deletes the entry with the lowest expiration time, which the
// min-heap keeps at position 0, and returns its key and recorded size.
func (h *indexedHeap) removeFirst() (key string, size uint) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned key and size pair
	const rootPos = 0
	key, size = h.removeInternal(rootPos)
	return key, size
}
================================================
FILE: middleware/cache/manager.go
================================================
package cache
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/memory"
)
// item is a cached response entry. With an external fiber.Storage it is
// msgp-encoded by manager.set; in memory mode the *item pointer itself is stored.
//
// msgp -file="manager.go" -o="manager_msgp.go" -tests=true -unexported
// Default slice limits are sized for cache payloads, with tighter field caps below.
//
//go:generate msgp -o=manager_msgp.go -tests=true -unexported
//nolint:revive // msgp requires tags on unexported fields for limit enforcement.
type item struct {
headers []cachedHeader `msg:",limit=1024"` // Typical HTTP header count stays well below this.
body []byte // Cache bodies are bounded by storage policy, not msgp limits.
ctype []byte `msg:",limit=256"` // Content-Type values are short per RFCs.
cencoding []byte `msg:",limit=128"` // Content-Encoding is typically a short token.
cacheControl []byte `msg:",limit=2048"` // Cache-Control directives are bounded.
expires []byte `msg:",limit=128"` // Expires is a short HTTP-date string.
etag []byte `msg:",limit=256"` // ETags are small tokens/quoted strings.
// NOTE(review): units of the uint64 time fields below are set by the
// middleware that fills them (likely seconds) — confirm before relying on them.
date uint64
status int
age uint64
exp uint64
ttl uint64
forceRevalidate bool
revalidate bool
shareable bool
private bool
// used for finding the item in an indexed heap
heapidx int
}
// cachedHeader is a single header key/value pair held in item.headers.
//
//nolint:revive // msgp requires tags on unexported fields for limit enforcement.
type cachedHeader struct {
key []byte `msg:",limit=512"` // Header names are small.
value []byte `msg:",limit=16384"` // Header values are bounded to reasonable sizes.
}
// manager stores cache entries either in a caller-supplied fiber.Storage
// (msgp-encoded) or, when none is supplied, in an in-process memory.Storage.
//
//msgp:ignore manager
type manager struct {
// pool recycles *item values; release only returns items to it in storage mode.
pool sync.Pool
// memory is the fallback backend, created by newManager when storage is nil.
memory *memory.Storage
// storage is the external backend; nil means memory mode.
storage fiber.Storage
// redactKeys makes logKey return redactedKey instead of the real key.
redactKeys bool
}
// redactedKey is substituted for cache keys in error messages when key
// redaction is enabled (see logKey).
const redactedKey = "[redacted]"
// errCacheMiss is returned by get and getRaw when no entry exists for a key.
var errCacheMiss = errors.New("cache: miss")
// newManager builds a cache manager. A non-nil storage is used as the backend
// for msgp-encoded entries; otherwise an in-process memory.Storage is created.
// redactKeys controls whether error messages include the real cache key.
func newManager(storage fiber.Storage, redactKeys bool) *manager {
	m := &manager{redactKeys: redactKeys}
	m.pool.New = func() any { return new(item) }

	if storage == nil {
		// No external backend: keep entries in memory.
		m.memory = memory.New()
		return m
	}
	m.storage = storage
	return m
}
// acquire takes an *item from the pool; the pool's New allocates a fresh
// zero-value item when none is available.
func (m *manager) acquire() *item {
	it := m.pool.Get().(*item) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
	return it
}
// release zeroes e and returns it to the pool, but only in storage mode. In
// memory mode, set hands the *item pointer itself to the memory store, so
// pooling it would let a later acquire overwrite a live cache entry.
func (m *manager) release(e *item) {
	if m.storage == nil {
		return
	}
	// Reset every field at once (same result as clearing each field by hand).
	*e = item{}
	m.pool.Put(e)
}
// get loads the item stored under key. In storage mode the raw bytes are
// decoded into an item taken from the pool; in memory mode the stored *item
// pointer is returned directly. A missing key yields errCacheMiss.
func (m *manager) get(ctx context.Context, key string) (*item, error) {
	if m.storage == nil {
		value := m.memory.Get(key)
		if value == nil {
			return nil, errCacheMiss
		}
		it, ok := value.(*item)
		if !ok {
			return nil, fmt.Errorf("cache: unexpected entry type %T for key %q", value, m.logKey(key))
		}
		return it, nil
	}

	raw, err := m.storage.GetWithContext(ctx, key)
	switch {
	case err != nil:
		return nil, fmt.Errorf("cache: failed to get key %q from storage: %w", m.logKey(key), err)
	case raw == nil:
		return nil, errCacheMiss
	}

	decoded := m.acquire()
	if _, err := decoded.UnmarshalMsg(raw); err != nil {
		// Hand the half-filled item back before reporting the failure.
		m.release(decoded)
		return nil, fmt.Errorf("cache: failed to unmarshal key %q: %w", m.logKey(key), err)
	}
	return decoded, nil
}
// getRaw returns the raw bytes stored under key, from the external storage or
// from the memory store. A missing key yields errCacheMiss.
func (m *manager) getRaw(ctx context.Context, key string) ([]byte, error) {
	if m.storage == nil {
		value := m.memory.Get(key)
		if value == nil {
			return nil, errCacheMiss
		}
		raw, ok := value.([]byte)
		if !ok {
			return nil, fmt.Errorf("cache: unexpected raw entry type %T for key %q", value, m.logKey(key))
		}
		return raw, nil
	}

	raw, err := m.storage.GetWithContext(ctx, key)
	switch {
	case err != nil:
		return nil, fmt.Errorf("cache: failed to get raw key %q from storage: %w", m.logKey(key), err)
	case raw == nil:
		return nil, errCacheMiss
	}
	return raw, nil
}
// set stores it under key with expiration exp. In storage mode it is
// msgp-encoded and then always returned to the pool, whether or not the write
// succeeds. In memory mode the pointer itself is kept by the memory store.
func (m *manager) set(ctx context.Context, key string, it *item, exp time.Duration) error {
	if m.storage == nil {
		m.memory.Set(key, it, exp)
		return nil
	}

	defer m.release(it)

	raw, err := it.MarshalMsg(nil)
	if err != nil {
		return fmt.Errorf("cache: failed to marshal key %q: %w", m.logKey(key), err)
	}
	if err := m.storage.SetWithContext(ctx, key, raw, exp); err != nil {
		return fmt.Errorf("cache: failed to store key %q: %w", m.logKey(key), err)
	}
	return nil
}
// setRaw stores already-encoded bytes under key with expiration exp, in the
// external storage or in the memory store.
func (m *manager) setRaw(ctx context.Context, key string, raw []byte, exp time.Duration) error {
	if m.storage == nil {
		m.memory.Set(key, raw, exp)
		return nil
	}
	err := m.storage.SetWithContext(ctx, key, raw, exp)
	if err == nil {
		return nil
	}
	return fmt.Errorf("cache: failed to store raw key %q: %w", m.logKey(key), err)
}
// del removes key from the external storage or from the memory store.
func (m *manager) del(ctx context.Context, key string) error {
	if m.storage == nil {
		m.memory.Delete(key)
		return nil
	}
	err := m.storage.DeleteWithContext(ctx, key)
	if err == nil {
		return nil
	}
	return fmt.Errorf("cache: failed to delete key %q: %w", m.logKey(key), err)
}
// logKey returns the form of key that may appear in error messages: the key
// itself, or redactedKey when redaction is enabled.
func (m *manager) logKey(key string) string {
	if !m.redactKeys {
		return key
	}
	return redactedKey
}
================================================
FILE: middleware/cache/manager_msgp.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package cache
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
//
// It reads a msgpack map from dc. Unknown field names are skipped. "key" and
// "value" are read with ReadBytesLimit caps of 512 and 16384 bytes, matching
// the msg tags on cachedHeader, and a nil result becomes an empty slice.
func (z *cachedHeader) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "key":
z.key, err = dc.ReadBytesLimit(z.key, 512)
if err == nil && z.key == nil {
z.key = []byte{}
}
if err != nil {
err = msgp.WrapError(err, "key")
return
}
case "value":
z.value, err = dc.ReadBytesLimit(z.value, 16384)
if err == nil && z.value == nil {
z.value = []byte{}
}
if err != nil {
err = msgp.WrapError(err, "value")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
//
// It writes a 2-entry map. The literal bytes passed to Append are the
// pre-encoded map header (0x82) and the fixstr field names "key" and "value".
func (z *cachedHeader) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 2
// write "key"
err = en.Append(0x82, 0xa3, 0x6b, 0x65, 0x79)
if err != nil {
return
}
err = en.WriteBytes(z.key)
if err != nil {
err = msgp.WrapError(err, "key")
return
}
// write "value"
err = en.Append(0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
if err != nil {
return
}
err = en.WriteBytes(z.value)
if err != nil {
err = msgp.WrapError(err, "value")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
//
// It appends the same encoding EncodeMsg writes to b, first reserving
// Msgsize() bytes of capacity via msgp.Require.
func (z *cachedHeader) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 2
// string "key"
o = append(o, 0x82, 0xa3, 0x6b, 0x65, 0x79)
o = msgp.AppendBytes(o, z.key)
// string "value"
o = append(o, 0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
o = msgp.AppendBytes(o, z.value)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
//
// It decodes from bts and returns the unread remainder in o. Each byte field's
// length header is checked against its limit (512 / 16384) before copying,
// and existing capacity in z.key / z.value is reused when large enough.
func (z *cachedHeader) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "key":
var zb0002 uint32
zb0002, bts, err = msgp.ReadBytesHeader(bts)
if err != nil {
err = msgp.WrapError(err, "key")
return
}
if zb0002 > 512 {
err = msgp.ErrLimitExceeded
return
}
if z.key == nil || uint32(cap(z.key)) < zb0002 {
z.key = make([]byte, zb0002)
} else {
z.key = z.key[:zb0002]
}
if uint32(len(bts)) < zb0002 {
err = msgp.ErrShortBytes
return
}
copy(z.key, bts[:zb0002])
bts = bts[zb0002:]
case "value":
var zb0003 uint32
zb0003, bts, err = msgp.ReadBytesHeader(bts)
if err != nil {
err = msgp.WrapError(err, "value")
return
}
if zb0003 > 16384 {
err = msgp.ErrLimitExceeded
return
}
if z.value == nil || uint32(cap(z.value)) < zb0003 {
z.value = make([]byte, zb0003)
} else {
z.value = z.value[:zb0003]
}
if uint32(len(bts)) < zb0003 {
err = msgp.ErrShortBytes
return
}
copy(z.value, bts[:zb0003])
bts = bts[zb0003:]
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
//
// 1 is the fixmap header; 4 and 6 are the encoded field names "key" and
// "value" (1-byte fixstr prefix plus the name).
func (z *cachedHeader) Msgsize() (s int) {
s = 1 + 4 + msgp.BytesPrefixSize + len(z.key) + 6 + msgp.BytesPrefixSize + len(z.value)
return
}
// DecodeMsg implements msgp.Decodable
//
// It reads an item map from dc and enforces the msg limit caps declared on
// item and cachedHeader: at most 1024 headers, per-field byte caps, and a 1024
// cap on each nested header map's entry count. body is read without a limit.
// Unknown field names are skipped.
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "headers":
var zb0002 uint32
zb0002, err = dc.ReadArrayHeader()
if err != nil {
err = msgp.WrapError(err, "headers")
return
}
if zb0002 > 1024 {
err = msgp.ErrLimitExceeded
return
}
if cap(z.headers) >= int(zb0002) {
z.headers = (z.headers)[:zb0002]
} else {
z.headers = make([]cachedHeader, zb0002)
}
for za0001 := range z.headers {
var zb0003 uint32
zb0003, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err, "headers", za0001)
return
}
if zb0003 > 1024 {
err = msgp.ErrLimitExceeded
return
}
for zb0003 > 0 {
zb0003--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err, "headers", za0001)
return
}
switch msgp.UnsafeString(field) {
case "key":
z.headers[za0001].key, err = dc.ReadBytesLimit(z.headers[za0001].key, 512)
if err == nil && z.headers[za0001].key == nil {
z.headers[za0001].key = []byte{}
}
if err != nil {
err = msgp.WrapError(err, "headers", za0001, "key")
return
}
case "value":
z.headers[za0001].value, err = dc.ReadBytesLimit(z.headers[za0001].value, 16384)
if err == nil && z.headers[za0001].value == nil {
z.headers[za0001].value = []byte{}
}
if err != nil {
err = msgp.WrapError(err, "headers", za0001, "value")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err, "headers", za0001)
return
}
}
}
}
case "body":
z.body, err = dc.ReadBytes(z.body)
if err != nil {
err = msgp.WrapError(err, "body")
return
}
case "ctype":
z.ctype, err = dc.ReadBytesLimit(z.ctype, 256)
if err == nil && z.ctype == nil {
z.ctype = []byte{}
}
if err != nil {
err = msgp.WrapError(err, "ctype")
return
}
case "cencoding":
z.cencoding, err = dc.ReadBytesLimit(z.cencoding, 128)
if err == nil && z.cencoding == nil {
z.cencoding = []byte{}
}
if err != nil {
err = msgp.WrapError(err, "cencoding")
return
}
case "cacheControl":
z.cacheControl, err = dc.ReadBytesLimit(z.cacheControl, 2048)
if err == nil && z.cacheControl == nil {
z.cacheControl = []byte{}
}
if err != nil {
err = msgp.WrapError(err, "cacheControl")
return
}
case "expires":
z.expires, err = dc.ReadBytesLimit(z.expires, 128)
if err == nil && z.expires == nil {
z.expires = []byte{}
}
if err != nil {
err = msgp.WrapError(err, "expires")
return
}
case "etag":
z.etag, err = dc.ReadBytesLimit(z.etag, 256)
if err == nil && z.etag == nil {
z.etag = []byte{}
}
if err != nil {
err = msgp.WrapError(err, "etag")
return
}
case "date":
z.date, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "date")
return
}
case "status":
z.status, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "status")
return
}
case "age":
z.age, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "age")
return
}
case "exp":
z.exp, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
case "ttl":
z.ttl, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "ttl")
return
}
case "forceRevalidate":
z.forceRevalidate, err = dc.ReadBool()
if err != nil {
err = msgp.WrapError(err, "forceRevalidate")
return
}
case "revalidate":
z.revalidate, err = dc.ReadBool()
if err != nil {
err = msgp.WrapError(err, "revalidate")
return
}
case "shareable":
z.shareable, err = dc.ReadBool()
if err != nil {
err = msgp.WrapError(err, "shareable")
return
}
case "private":
z.private, err = dc.ReadBool()
if err != nil {
err = msgp.WrapError(err, "private")
return
}
case "heapidx":
z.heapidx, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "heapidx")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
//
// It writes all 17 item fields as a map. The leading bytes 0xde 0x00 0x11 are
// a map16 header with 17 entries; the other Append calls carry the pre-encoded
// fixstr field names.
func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 17
// write "headers"
err = en.Append(0xde, 0x0, 0x11, 0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
if err != nil {
return
}
err = en.WriteArrayHeader(uint32(len(z.headers)))
if err != nil {
err = msgp.WrapError(err, "headers")
return
}
for za0001 := range z.headers {
// map header, size 2
// write "key"
err = en.Append(0x82, 0xa3, 0x6b, 0x65, 0x79)
if err != nil {
return
}
err = en.WriteBytes(z.headers[za0001].key)
if err != nil {
err = msgp.WrapError(err, "headers", za0001, "key")
return
}
// write "value"
err = en.Append(0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
if err != nil {
return
}
err = en.WriteBytes(z.headers[za0001].value)
if err != nil {
err = msgp.WrapError(err, "headers", za0001, "value")
return
}
}
// write "body"
err = en.Append(0xa4, 0x62, 0x6f, 0x64, 0x79)
if err != nil {
return
}
err = en.WriteBytes(z.body)
if err != nil {
err = msgp.WrapError(err, "body")
return
}
// write "ctype"
err = en.Append(0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
if err != nil {
return
}
err = en.WriteBytes(z.ctype)
if err != nil {
err = msgp.WrapError(err, "ctype")
return
}
// write "cencoding"
err = en.Append(0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
if err != nil {
return
}
err = en.WriteBytes(z.cencoding)
if err != nil {
err = msgp.WrapError(err, "cencoding")
return
}
// write "cacheControl"
err = en.Append(0xac, 0x63, 0x61, 0x63, 0x68, 0x65, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c)
if err != nil {
return
}
err = en.WriteBytes(z.cacheControl)
if err != nil {
err = msgp.WrapError(err, "cacheControl")
return
}
// write "expires"
err = en.Append(0xa7, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73)
if err != nil {
return
}
err = en.WriteBytes(z.expires)
if err != nil {
err = msgp.WrapError(err, "expires")
return
}
// write "etag"
err = en.Append(0xa4, 0x65, 0x74, 0x61, 0x67)
if err != nil {
return
}
err = en.WriteBytes(z.etag)
if err != nil {
err = msgp.WrapError(err, "etag")
return
}
// write "date"
err = en.Append(0xa4, 0x64, 0x61, 0x74, 0x65)
if err != nil {
return
}
err = en.WriteUint64(z.date)
if err != nil {
err = msgp.WrapError(err, "date")
return
}
// write "status"
err = en.Append(0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.status)
if err != nil {
err = msgp.WrapError(err, "status")
return
}
// write "age"
err = en.Append(0xa3, 0x61, 0x67, 0x65)
if err != nil {
return
}
err = en.WriteUint64(z.age)
if err != nil {
err = msgp.WrapError(err, "age")
return
}
// write "exp"
err = en.Append(0xa3, 0x65, 0x78, 0x70)
if err != nil {
return
}
err = en.WriteUint64(z.exp)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
// write "ttl"
err = en.Append(0xa3, 0x74, 0x74, 0x6c)
if err != nil {
return
}
err = en.WriteUint64(z.ttl)
if err != nil {
err = msgp.WrapError(err, "ttl")
return
}
// write "forceRevalidate"
err = en.Append(0xaf, 0x66, 0x6f, 0x72, 0x63, 0x65, 0x52, 0x65, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65)
if err != nil {
return
}
err = en.WriteBool(z.forceRevalidate)
if err != nil {
err = msgp.WrapError(err, "forceRevalidate")
return
}
// write "revalidate"
err = en.Append(0xaa, 0x72, 0x65, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65)
if err != nil {
return
}
err = en.WriteBool(z.revalidate)
if err != nil {
err = msgp.WrapError(err, "revalidate")
return
}
// write "shareable"
err = en.Append(0xa9, 0x73, 0x68, 0x61, 0x72, 0x65, 0x61, 0x62, 0x6c, 0x65)
if err != nil {
return
}
err = en.WriteBool(z.shareable)
if err != nil {
err = msgp.WrapError(err, "shareable")
return
}
// write "private"
err = en.Append(0xa7, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65)
if err != nil {
return
}
err = en.WriteBool(z.private)
if err != nil {
err = msgp.WrapError(err, "private")
return
}
// write "heapidx"
err = en.Append(0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78)
if err != nil {
return
}
err = en.WriteInt(z.heapidx)
if err != nil {
err = msgp.WrapError(err, "heapidx")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
//
// It appends the same 17-field map encoding EncodeMsg writes to b, first
// reserving Msgsize() bytes of capacity via msgp.Require.
func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 17
// string "headers"
o = append(o, 0xde, 0x0, 0x11, 0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
o = msgp.AppendArrayHeader(o, uint32(len(z.headers)))
for za0001 := range z.headers {
// map header, size 2
// string "key"
o = append(o, 0x82, 0xa3, 0x6b, 0x65, 0x79)
o = msgp.AppendBytes(o, z.headers[za0001].key)
// string "value"
o = append(o, 0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
o = msgp.AppendBytes(o, z.headers[za0001].value)
}
// string "body"
o = append(o, 0xa4, 0x62, 0x6f, 0x64, 0x79)
o = msgp.AppendBytes(o, z.body)
// string "ctype"
o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
o = msgp.AppendBytes(o, z.ctype)
// string "cencoding"
o = append(o, 0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
o = msgp.AppendBytes(o, z.cencoding)
// string "cacheControl"
o = append(o, 0xac, 0x63, 0x61, 0x63, 0x68, 0x65, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c)
o = msgp.AppendBytes(o, z.cacheControl)
// string "expires"
o = append(o, 0xa7, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73)
o = msgp.AppendBytes(o, z.expires)
// string "etag"
o = append(o, 0xa4, 0x65, 0x74, 0x61, 0x67)
o = msgp.AppendBytes(o, z.etag)
// string "date"
o = append(o, 0xa4, 0x64, 0x61, 0x74, 0x65)
o = msgp.AppendUint64(o, z.date)
// string "status"
o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
o = msgp.AppendInt(o, z.status)
// string "age"
o = append(o, 0xa3, 0x61, 0x67, 0x65)
o = msgp.AppendUint64(o, z.age)
// string "exp"
o = append(o, 0xa3, 0x65, 0x78, 0x70)
o = msgp.AppendUint64(o, z.exp)
// string "ttl"
o = append(o, 0xa3, 0x74, 0x74, 0x6c)
o = msgp.AppendUint64(o, z.ttl)
// string "forceRevalidate"
o = append(o, 0xaf, 0x66, 0x6f, 0x72, 0x63, 0x65, 0x52, 0x65, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65)
o = msgp.AppendBool(o, z.forceRevalidate)
// string "revalidate"
o = append(o, 0xaa, 0x72, 0x65, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65)
o = msgp.AppendBool(o, z.revalidate)
// string "shareable"
o = append(o, 0xa9, 0x73, 0x68, 0x61, 0x72, 0x65, 0x61, 0x62, 0x6c, 0x65)
o = msgp.AppendBool(o, z.shareable)
// string "private"
o = append(o, 0xa7, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65)
o = msgp.AppendBool(o, z.private)
// string "heapidx"
o = append(o, 0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78)
o = msgp.AppendInt(o, z.heapidx)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
//
// It decodes an item from bts and returns the unread remainder in o. Every
// capped byte field checks its length header against the msg limit before
// copying, and existing slice capacity is reused when large enough. body uses
// ReadBytesBytes with no limit. Unknown field names are skipped.
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "headers":
var zb0002 uint32
zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "headers")
return
}
if zb0002 > 1024 {
err = msgp.ErrLimitExceeded
return
}
if cap(z.headers) >= int(zb0002) {
z.headers = (z.headers)[:zb0002]
} else {
z.headers = make([]cachedHeader, zb0002)
}
for za0001 := range z.headers {
var zb0003 uint32
zb0003, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "headers", za0001)
return
}
if zb0003 > 1024 {
err = msgp.ErrLimitExceeded
return
}
for zb0003 > 0 {
zb0003--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err, "headers", za0001)
return
}
switch msgp.UnsafeString(field) {
case "key":
var zb0004 uint32
zb0004, bts, err = msgp.ReadBytesHeader(bts)
if err != nil {
err = msgp.WrapError(err, "headers", za0001, "key")
return
}
if zb0004 > 512 {
err = msgp.ErrLimitExceeded
return
}
if z.headers[za0001].key == nil || uint32(cap(z.headers[za0001].key)) < zb0004 {
z.headers[za0001].key = make([]byte, zb0004)
} else {
z.headers[za0001].key = z.headers[za0001].key[:zb0004]
}
if uint32(len(bts)) < zb0004 {
err = msgp.ErrShortBytes
return
}
copy(z.headers[za0001].key, bts[:zb0004])
bts = bts[zb0004:]
case "value":
var zb0005 uint32
zb0005, bts, err = msgp.ReadBytesHeader(bts)
if err != nil {
err = msgp.WrapError(err, "headers", za0001, "value")
return
}
if zb0005 > 16384 {
err = msgp.ErrLimitExceeded
return
}
if z.headers[za0001].value == nil || uint32(cap(z.headers[za0001].value)) < zb0005 {
z.headers[za0001].value = make([]byte, zb0005)
} else {
z.headers[za0001].value = z.headers[za0001].value[:zb0005]
}
if uint32(len(bts)) < zb0005 {
err = msgp.ErrShortBytes
return
}
copy(z.headers[za0001].value, bts[:zb0005])
bts = bts[zb0005:]
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err, "headers", za0001)
return
}
}
}
}
case "body":
z.body, bts, err = msgp.ReadBytesBytes(bts, z.body)
if err != nil {
err = msgp.WrapError(err, "body")
return
}
case "ctype":
var zb0006 uint32
zb0006, bts, err = msgp.ReadBytesHeader(bts)
if err != nil {
err = msgp.WrapError(err, "ctype")
return
}
if zb0006 > 256 {
err = msgp.ErrLimitExceeded
return
}
if z.ctype == nil || uint32(cap(z.ctype)) < zb0006 {
z.ctype = make([]byte, zb0006)
} else {
z.ctype = z.ctype[:zb0006]
}
if uint32(len(bts)) < zb0006 {
err = msgp.ErrShortBytes
return
}
copy(z.ctype, bts[:zb0006])
bts = bts[zb0006:]
case "cencoding":
var zb0007 uint32
zb0007, bts, err = msgp.ReadBytesHeader(bts)
if err != nil {
err = msgp.WrapError(err, "cencoding")
return
}
if zb0007 > 128 {
err = msgp.ErrLimitExceeded
return
}
if z.cencoding == nil || uint32(cap(z.cencoding)) < zb0007 {
z.cencoding = make([]byte, zb0007)
} else {
z.cencoding = z.cencoding[:zb0007]
}
if uint32(len(bts)) < zb0007 {
err = msgp.ErrShortBytes
return
}
copy(z.cencoding, bts[:zb0007])
bts = bts[zb0007:]
case "cacheControl":
var zb0008 uint32
zb0008, bts, err = msgp.ReadBytesHeader(bts)
if err != nil {
err = msgp.WrapError(err, "cacheControl")
return
}
if zb0008 > 2048 {
err = msgp.ErrLimitExceeded
return
}
if z.cacheControl == nil || uint32(cap(z.cacheControl)) < zb0008 {
z.cacheControl = make([]byte, zb0008)
} else {
z.cacheControl = z.cacheControl[:zb0008]
}
if uint32(len(bts)) < zb0008 {
err = msgp.ErrShortBytes
return
}
copy(z.cacheControl, bts[:zb0008])
bts = bts[zb0008:]
case "expires":
var zb0009 uint32
zb0009, bts, err = msgp.ReadBytesHeader(bts)
if err != nil {
err = msgp.WrapError(err, "expires")
return
}
if zb0009 > 128 {
err = msgp.ErrLimitExceeded
return
}
if z.expires == nil || uint32(cap(z.expires)) < zb0009 {
z.expires = make([]byte, zb0009)
} else {
z.expires = z.expires[:zb0009]
}
if uint32(len(bts)) < zb0009 {
err = msgp.ErrShortBytes
return
}
copy(z.expires, bts[:zb0009])
bts = bts[zb0009:]
case "etag":
var zb0010 uint32
zb0010, bts, err = msgp.ReadBytesHeader(bts)
if err != nil {
err = msgp.WrapError(err, "etag")
return
}
if zb0010 > 256 {
err = msgp.ErrLimitExceeded
return
}
if z.etag == nil || uint32(cap(z.etag)) < zb0010 {
z.etag = make([]byte, zb0010)
} else {
z.etag = z.etag[:zb0010]
}
if uint32(len(bts)) < zb0010 {
err = msgp.ErrShortBytes
return
}
copy(z.etag, bts[:zb0010])
bts = bts[zb0010:]
case "date":
z.date, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "date")
return
}
case "status":
z.status, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "status")
return
}
case "age":
z.age, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "age")
return
}
case "exp":
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
case "ttl":
z.ttl, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "ttl")
return
}
case "forceRevalidate":
z.forceRevalidate, bts, err = msgp.ReadBoolBytes(bts)
if err != nil {
err = msgp.WrapError(err, "forceRevalidate")
return
}
case "revalidate":
z.revalidate, bts, err = msgp.ReadBoolBytes(bts)
if err != nil {
err = msgp.WrapError(err, "revalidate")
return
}
case "shareable":
z.shareable, bts, err = msgp.ReadBoolBytes(bts)
if err != nil {
err = msgp.WrapError(err, "shareable")
return
}
case "private":
z.private, bts, err = msgp.ReadBoolBytes(bts)
if err != nil {
err = msgp.WrapError(err, "private")
return
}
case "heapidx":
z.heapidx, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "heapidx")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
//
// 3 is the map16 header. Each other small constant is an encoded field name
// (1-byte fixstr prefix plus the name, e.g. 8 for "headers", 16 for
// "forceRevalidate").
func (z *item) Msgsize() (s int) {
s = 3 + 8 + msgp.ArrayHeaderSize
for za0001 := range z.headers {
s += 1 + 4 + msgp.BytesPrefixSize + len(z.headers[za0001].key) + 6 + msgp.BytesPrefixSize + len(z.headers[za0001].value)
}
s += 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 13 + msgp.BytesPrefixSize + len(z.cacheControl) + 8 + msgp.BytesPrefixSize + len(z.expires) + 5 + msgp.BytesPrefixSize + len(z.etag) + 5 + msgp.Uint64Size + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 4 + msgp.Uint64Size + 4 + msgp.Uint64Size + 16 + msgp.BoolSize + 11 + msgp.BoolSize + 10 + msgp.BoolSize + 8 + msgp.BoolSize + 8 + msgp.IntSize
return
}
================================================
FILE: middleware/cache/manager_msgp_test.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package cache
import (
"bytes"
"testing"
"github.com/tinylib/msgp/msgp"
)
// TestMarshalUnmarshalcachedHeader round-trips a zero-value cachedHeader
// through MarshalMsg/UnmarshalMsg and checks that neither UnmarshalMsg nor
// msgp.Skip leaves trailing bytes.
func TestMarshalUnmarshalcachedHeader(t *testing.T) {
v := cachedHeader{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
// BenchmarkMarshalMsgcachedHeader measures MarshalMsg into a fresh (nil)
// buffer on every iteration.
func BenchmarkMarshalMsgcachedHeader(b *testing.B) {
v := cachedHeader{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
// BenchmarkAppendMsgcachedHeader measures MarshalMsg appending into a reused
// buffer pre-sized with Msgsize.
func BenchmarkAppendMsgcachedHeader(b *testing.B) {
v := cachedHeader{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
// BenchmarkUnmarshalcachedHeader measures UnmarshalMsg on a pre-encoded
// zero-value cachedHeader.
func BenchmarkUnmarshalcachedHeader(b *testing.B) {
v := cachedHeader{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
// TestEncodeDecodecachedHeader round-trips a cachedHeader through the
// streaming Encode/Decode API and checks that the stream can be skipped. An
// encoding larger than Msgsize() only logs a warning.
func TestEncodeDecodecachedHeader(t *testing.T) {
v := cachedHeader{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecodecachedHeader Msgsize() is inaccurate")
}
vn := cachedHeader{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
// BenchmarkEncodecachedHeader measures EncodeMsg into a writer that discards
// its output (msgp.Nowhere).
func BenchmarkEncodecachedHeader(b *testing.B) {
v := cachedHeader{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
// BenchmarkDecodecachedHeader measures DecodeMsg reading from an endless
// reader that repeats one encoded cachedHeader.
func BenchmarkDecodecachedHeader(b *testing.B) {
v := cachedHeader{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}
// TestMarshalUnmarshalitem round-trips a zero-value item through
// MarshalMsg/UnmarshalMsg and checks that neither UnmarshalMsg nor msgp.Skip
// leaves trailing bytes.
func TestMarshalUnmarshalitem(t *testing.T) {
v := item{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
// BenchmarkMarshalMsgitem measures MarshalMsg into a fresh (nil) buffer on
// every iteration.
func BenchmarkMarshalMsgitem(b *testing.B) {
v := item{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
// BenchmarkAppendMsgitem measures MarshalMsg appending into a reused buffer
// pre-sized with Msgsize.
func BenchmarkAppendMsgitem(b *testing.B) {
v := item{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
// BenchmarkUnmarshalitem measures UnmarshalMsg on a pre-encoded zero-value
// item.
func BenchmarkUnmarshalitem(b *testing.B) {
v := item{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
// TestEncodeDecodeitem round-trips an item through the streaming
// Encode/Decode API and checks that the stream can be skipped. An encoding
// larger than Msgsize() only logs a warning.
func TestEncodeDecodeitem(t *testing.T) {
v := item{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecodeitem Msgsize() is inaccurate")
}
vn := item{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
// BenchmarkEncodeitem measures EncodeMsg into a writer that discards its
// output (msgp.Nowhere).
func BenchmarkEncodeitem(b *testing.B) {
v := item{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
// BenchmarkDecodeitem measures DecodeMsg reading from an endless reader that
// repeats one encoded item.
func BenchmarkDecodeitem(b *testing.B) {
v := item{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}
================================================
FILE: middleware/cache/manager_test.go
================================================
package cache
import (
"context"
"testing"
"time"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_manager_get checks that an in-memory manager reports errCacheMiss for
// unknown keys and returns the stored item, with its body intact, for known keys.
func Test_manager_get(t *testing.T) {
	t.Parallel()
	cacheManager := newManager(nil, true)

	t.Run("Item not found in cache", func(t *testing.T) {
		t.Parallel()
		it, err := cacheManager.get(context.Background(), utils.UUIDv4())
		require.ErrorIs(t, err, errCacheMiss)
		assert.Nil(t, it)
	})

	t.Run("Item found in cache", func(t *testing.T) {
		t.Parallel()
		id := utils.UUIDv4()
		cacheItem := cacheManager.acquire()
		cacheItem.body = []byte("test-body")
		require.NoError(t, cacheManager.set(context.Background(), id, cacheItem, 10*time.Second))

		it, err := cacheManager.get(context.Background(), id)
		require.NoError(t, err)
		require.NotNil(t, it)
		// A bare non-nil check would not catch a wrong or reset entry.
		assert.Equal(t, []byte("test-body"), it.body)
	})
}
// Test_manager_logKey checks that a redacting manager hides keys in logs while
// a non-redacting manager logs them verbatim.
func Test_manager_logKey(t *testing.T) {
	t.Parallel()
	const key = "secret"
	assert.Equal(t, redactedKey, newManager(nil, true).logKey(key))
	assert.Equal(t, key, newManager(nil, false).logKey(key))
}
================================================
FILE: middleware/compress/compress.go
================================================
package compress
import (
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/etag"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
func hasToken(header, token string) bool {
for part := range strings.SplitSeq(header, ",") {
if utils.EqualFold(utils.TrimSpace(part), token) {
return true
}
}
return false
}
// shouldSkip reports whether the response must be left uncompressed. This is
// the case for:
//   - HEAD requests;
//   - informational (< 200), 204, 205, 304 and 206 statuses;
//   - empty bodies;
//   - Range requests;
//   - a "no-transform" Cache-Control directive on either the request or the response.
func shouldSkip(c fiber.Ctx) bool {
	if c.Method() == fiber.MethodHead {
		return true
	}

	switch code := c.Response().StatusCode(); code {
	case fiber.StatusNoContent, fiber.StatusResetContent, fiber.StatusNotModified, fiber.StatusPartialContent:
		return true
	default:
		if code < 200 {
			return true
		}
	}

	if len(c.Response().Body()) == 0 || c.Get(fiber.HeaderRange) != "" {
		return true
	}

	return hasToken(c.Get(fiber.HeaderCacheControl), "no-transform") ||
		hasToken(c.GetRespHeader(fiber.HeaderCacheControl), "no-transform")
}
// appendVaryAcceptEncoding ensures the response Vary header covers
// Accept-Encoding. It sets the header when absent. It appends to an existing
// value unless that value already lists "*" or Accept-Encoding as a whole token.
func appendVaryAcceptEncoding(c fiber.Ctx) {
	current := c.GetRespHeader(fiber.HeaderVary)
	switch {
	case current == "":
		c.Set(fiber.HeaderVary, fiber.HeaderAcceptEncoding)
	case hasToken(current, "*"), hasToken(current, fiber.HeaderAcceptEncoding):
		// Already varies on everything, or on Accept-Encoding specifically.
	default:
		c.Set(fiber.HeaderVary, current+", "+fiber.HeaderAcceptEncoding)
	}
}
// New creates a new compression middleware handler.
//
// The configured Level selects a brotli level plus a level for the other
// encodings, both passed to fasthttp.CompressHandlerBrotliLevel. LevelDisabled
// yields a handler that only calls c.Next(). When a downstream handler returns
// an error, it is returned unchanged and no compression is attempted.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)

	var brotliLevel, otherLevel int
	switch cfg.Level {
	case LevelDefault:
		brotliLevel, otherLevel = fasthttp.CompressBrotliDefaultCompression, fasthttp.CompressDefaultCompression
	case LevelBestSpeed:
		brotliLevel, otherLevel = fasthttp.CompressBrotliBestSpeed, fasthttp.CompressBestSpeed
	case LevelBestCompression:
		brotliLevel, otherLevel = fasthttp.CompressBrotliBestCompression, fasthttp.CompressBestCompression
	default:
		// LevelDisabled: nothing to compress, just continue the chain.
		return func(c fiber.Ctx) error {
			return c.Next()
		}
	}

	// The wrapped handler is a no-op: the middleware runs the real handlers via
	// c.Next() and only uses the wrapper for its response-compression step.
	noop := func(_ *fasthttp.RequestCtx) {}
	compressor := fasthttp.CompressHandlerBrotliLevel(noop, brotliLevel, otherLevel)

	return func(c fiber.Ctx) error {
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		if err := c.Next(); err != nil {
			return err
		}

		// Leave responses that must not be compressed, or are already encoded, alone.
		if shouldSkip(c) || c.GetRespHeader(fiber.HeaderContentEncoding) != "" {
			appendVaryAcceptEncoding(c)
			return nil
		}

		compressor(c.RequestCtx())

		// A strong ETag must track the bytes actually sent, so regenerate it
		// whenever the compressor applied a Content-Encoding. Weak tags stay as-is.
		tag := c.GetRespHeader(fiber.HeaderETag)
		if tag != "" && !strings.HasPrefix(tag, "W/") && c.GetRespHeader(fiber.HeaderContentEncoding) != "" {
			c.Set(fiber.HeaderETag, string(etag.Generate(c.Response().Body())))
		}

		appendVaryAcceptEncoding(c)
		return nil
	}
}
================================================
FILE: middleware/compress/compress_test.go
================================================
package compress
import (
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/etag"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
var (
	// filedata holds the repository README, used as a compressible payload.
	filedata []byte

	// testConfig is passed to app.Test: a 10s timeout with FailOnTimeout enabled.
	testConfig = fiber.TestConfig{
		Timeout:       10 * time.Second,
		FailOnTimeout: true,
	}
)

// init loads filedata once for the whole package and panics if the file is unreadable.
func init() {
	contents, readErr := os.ReadFile("../../.github/README.md")
	if readErr != nil {
		panic(readErr)
	}
	filedata = contents
}
// go test -run Test_Compress_Gzip
//
// Test_Compress_Gzip checks that a gzip-accepting client receives a gzip body
// that is smaller than the original payload.
func Test_Compress_Gzip(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
		return c.Send(filedata)
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusOK, response.StatusCode, "Status code")
	require.Equal(t, "gzip", response.Header.Get(fiber.HeaderContentEncoding))

	compressed, err := io.ReadAll(response.Body)
	require.NoError(t, err)
	require.Less(t, len(compressed), len(filedata))
}
// go test -run Test_Compress_Different_Level
//
// Test_Compress_Different_Level runs every enabled Level against every
// supported encoding. Each combination must be applied and must shrink the payload.
func Test_Compress_Different_Level(t *testing.T) {
	t.Parallel()
	encodings := []string{"gzip", "deflate", "br", "zstd"}
	levels := []Level{LevelDefault, LevelBestSpeed, LevelBestCompression}

	for _, encoding := range encodings {
		for _, lvl := range levels {
			t.Run(fmt.Sprintf("%s_level %d", encoding, lvl), func(t *testing.T) {
				t.Parallel()

				server := fiber.New()
				server.Use(New(Config{Level: lvl}))
				server.Get("/", func(c fiber.Ctx) error {
					c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
					return c.Send(filedata)
				})

				request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
				request.Header.Set("Accept-Encoding", encoding)

				response, err := server.Test(request, testConfig)
				require.NoError(t, err, "app.Test(req)")
				require.Equal(t, fiber.StatusOK, response.StatusCode, "Status code")
				require.Equal(t, encoding, response.Header.Get(fiber.HeaderContentEncoding))

				compressed, err := io.ReadAll(response.Body)
				require.NoError(t, err)
				require.Less(t, len(compressed), len(filedata))
			})
		}
	}
}
// Test_Compress_Deflate checks that a deflate-only client receives a smaller deflate body.
func Test_Compress_Deflate(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		return c.Send(filedata)
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "deflate")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusOK, response.StatusCode, "Status code")
	require.Equal(t, "deflate", response.Header.Get(fiber.HeaderContentEncoding))

	compressed, err := io.ReadAll(response.Body)
	require.NoError(t, err)
	require.Less(t, len(compressed), len(filedata))
}
// Test_Compress_Brotli checks that a brotli-only client receives a smaller "br" body.
func Test_Compress_Brotli(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		return c.Send(filedata)
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "br")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusOK, response.StatusCode, "Status code")
	require.Equal(t, "br", response.Header.Get(fiber.HeaderContentEncoding))

	compressed, err := io.ReadAll(response.Body)
	require.NoError(t, err)
	require.Less(t, len(compressed), len(filedata))
}
// Test_Compress_Zstd checks that a zstd-only client receives a smaller zstd body.
func Test_Compress_Zstd(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		return c.Send(filedata)
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "zstd")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusOK, response.StatusCode, "Status code")
	require.Equal(t, "zstd", response.Header.Get(fiber.HeaderContentEncoding))

	compressed, err := io.ReadAll(response.Body)
	require.NoError(t, err)
	require.Less(t, len(compressed), len(filedata))
}
// Test_Compress_Disabled checks that LevelDisabled passes the body through
// unencoded and at full length.
func Test_Compress_Disabled(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New(Config{Level: LevelDisabled}))
	server.Get("/", func(c fiber.Ctx) error {
		return c.Send(filedata)
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "br")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusOK, response.StatusCode, "Status code")
	require.Empty(t, response.Header.Get(fiber.HeaderContentEncoding))

	raw, err := io.ReadAll(response.Body)
	require.NoError(t, err)
	require.Len(t, raw, len(filedata))
}
// Test_Compress_Adds_Vary_Header checks that Vary: Accept-Encoding is added
// when the handler sets no Vary header.
func Test_Compress_Adds_Vary_Header(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "Accept-Encoding", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Vary_Star checks that an existing "Vary: *" is not extended.
func Test_Compress_Vary_Star(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "*")
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "*", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Vary_List_Star checks that a Vary list already containing "*"
// is left unchanged.
func Test_Compress_Vary_List_Star(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "User-Agent, *")
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "User-Agent, *", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Vary_Similar_Substring checks that a token merely containing
// "Accept-Encoding" as a prefix does not count as a match, so the real token is appended.
func Test_Compress_Vary_Similar_Substring(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "Accept-Encoding2")
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "Accept-Encoding2, Accept-Encoding", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Skip_When_Content_Encoding_Set checks that a response already
// carrying Content-Encoding is passed through untouched. Body, encoding and
// ETag are preserved; only Vary is added.
func Test_Compress_Skip_When_Content_Encoding_Set(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentEncoding, "gzip")
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")

	payload, err := io.ReadAll(response.Body)
	require.NoError(t, err)
	require.Equal(t, "hello", string(payload))

	headers := response.Header
	require.Equal(t, "gzip", headers.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "\"abc\"", headers.Get(fiber.HeaderETag))
	require.Equal(t, "Accept-Encoding", headers.Get(fiber.HeaderVary))
}
// Test_Compress_Skip_When_Content_Encoding_Set_Vary_Star checks that the
// pre-encoded skip path also leaves "Vary: *" alone.
func Test_Compress_Skip_When_Content_Encoding_Set_Vary_Star(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentEncoding, "gzip")
		c.Set(fiber.HeaderVary, "*")
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "*", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Skip_When_Content_Encoding_Set_Vary_List_Star checks that the
// pre-encoded skip path leaves a Vary list containing "*" unchanged.
func Test_Compress_Skip_When_Content_Encoding_Set_Vary_List_Star(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentEncoding, "gzip")
		c.Set(fiber.HeaderVary, "User-Agent, *")
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "User-Agent, *", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Skip_When_Content_Encoding_Set_Vary_Similar_Substring checks
// that the pre-encoded skip path still appends Accept-Encoding when Vary only
// holds a look-alike token.
func Test_Compress_Skip_When_Content_Encoding_Set_Vary_Similar_Substring(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentEncoding, "gzip")
		c.Set(fiber.HeaderVary, "Accept-Encoding2")
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "Accept-Encoding2, Accept-Encoding", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Strong_ETag_Recalculated checks that a strong ETag is replaced
// by one generated from the compressed bytes that were actually sent.
func Test_Compress_Strong_ETag_Recalculated(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.Send(filedata)
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "gzip", response.Header.Get(fiber.HeaderContentEncoding))

	compressed, err := io.ReadAll(response.Body)
	require.NoError(t, err)
	require.Equal(t, string(etag.Generate(compressed)), response.Header.Get(fiber.HeaderETag))
}
// Test_Compress_Weak_ETag_Unchanged checks that a weak (W/) ETag survives compression as-is.
func Test_Compress_Weak_ETag_Unchanged(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
		c.Set(fiber.HeaderETag, "W/\"abc\"")
		return c.Send(filedata)
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")

	headers := response.Header
	require.Equal(t, "gzip", headers.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "W/\"abc\"", headers.Get(fiber.HeaderETag))
}
// Test_Compress_Strong_ETag_Unchanged_When_Not_Compressed checks that a strong
// ETag is kept when the compressor leaves a tiny body unencoded. Vary is still added.
func Test_Compress_Strong_ETag_Unchanged_When_Not_Compressed(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.SendString("tiny")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err)

	headers := response.Header
	require.Empty(t, headers.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "\"abc\"", headers.Get(fiber.HeaderETag))
	require.Equal(t, "Accept-Encoding", headers.Get(fiber.HeaderVary))
}
// Test_Compress_Skip_Head checks that the same route is compressed for GET and
// skipped for HEAD. The HEAD response keeps its ETag and still gets Vary.
func Test_Compress_Skip_Head(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	serve := func(c fiber.Ctx) error {
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.Send(filedata)
	}
	server.Get("/", serve)
	server.Head("/", serve)

	// GET: compressed body is returned.
	getRequest := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	getRequest.Header.Set("Accept-Encoding", "gzip")
	getResponse, err := server.Test(getRequest, testConfig)
	require.NoError(t, err, "app.Test(getReq)")
	getPayload, err := io.ReadAll(getResponse.Body)
	require.NoError(t, err)
	require.NotEmpty(t, getPayload)
	require.Equal(t, "gzip", getResponse.Header.Get(fiber.HeaderContentEncoding))

	// HEAD: no body and no Content-Encoding, but Vary and ETag are present.
	headRequest := httptest.NewRequest(fiber.MethodHead, "/", http.NoBody)
	headRequest.Header.Set("Accept-Encoding", "gzip")
	headResponse, err := server.Test(headRequest, testConfig)
	require.NoError(t, err, "app.Test(headReq)")
	headPayload, err := io.ReadAll(headResponse.Body)
	require.NoError(t, err)
	require.Empty(t, headPayload)

	headHeaders := headResponse.Header
	require.Empty(t, headHeaders.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "Accept-Encoding", headHeaders.Get(fiber.HeaderVary))
	require.Equal(t, "\"abc\"", headHeaders.Get(fiber.HeaderETag))
}
// Test_Compress_Skip_Status_NoContent checks that a 204 response is not
// encoded and keeps its ETag.
func Test_Compress_Skip_Status_NoContent(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.SendStatus(fiber.StatusNoContent)
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusNoContent, response.StatusCode)
	require.Empty(t, response.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "\"abc\"", response.Header.Get(fiber.HeaderETag))
}
// Test_Compress_Skip_Status_NotModified checks that a 304 response is not
// encoded and keeps its ETag.
func Test_Compress_Skip_Status_NotModified(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderETag, "\"abc\"")
		c.Status(fiber.StatusNotModified)
		return nil
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusNotModified, response.StatusCode)
	require.Empty(t, response.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "\"abc\"", response.Header.Get(fiber.HeaderETag))
}
// Test_Compress_Skip_Range checks that a Range request is served unencoded
// while still advertising Vary: Accept-Encoding.
func Test_Compress_Skip_Range(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")
	request.Header.Set("Range", "bytes=0-1")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Empty(t, response.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "Accept-Encoding", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Skip_Range_NoAcceptEncoding checks that the Range skip path
// adds Vary even when the client sent no Accept-Encoding.
func Test_Compress_Skip_Range_NoAcceptEncoding(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Range", "bytes=0-1")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Empty(t, response.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "Accept-Encoding", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Skip_Range_Vary_Star checks that the Range skip path leaves "Vary: *" intact.
func Test_Compress_Skip_Range_Vary_Star(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "*")
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")
	request.Header.Set("Range", "bytes=0-1")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Empty(t, response.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "*", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Skip_Range_Vary_Similar_Substring checks that the Range skip
// path appends Accept-Encoding next to a look-alike token.
func Test_Compress_Skip_Range_Vary_Similar_Substring(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "Accept-Encoding2")
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")
	request.Header.Set("Range", "bytes=0-1")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Empty(t, response.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "Accept-Encoding2, Accept-Encoding", response.Header.Get(fiber.HeaderVary))
}
// Test_Compress_Skip_Status_PartialContent checks that a 206 response is never encoded.
func Test_Compress_Skip_Status_PartialContent(t *testing.T) {
	t.Parallel()

	server := fiber.New()
	server.Use(New())
	server.Get("/", func(c fiber.Ctx) error {
		c.Status(fiber.StatusPartialContent)
		return c.SendString("hello")
	})

	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Set("Accept-Encoding", "gzip")

	response, err := server.Test(request, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusPartialContent, response.StatusCode)
	require.Empty(t, response.Header.Get(fiber.HeaderContentEncoding))
}
// Test_Compress_Skip_NoTransform checks that a "no-transform" Cache-Control
// directive disables compression. The directive is tried once on the request
// and once on the response.
func Test_Compress_Skip_NoTransform(t *testing.T) {
	t.Parallel()
	cases := []struct {
		name      string
		onRequest bool
	}{
		{name: "request", onRequest: true},
		{name: "response", onRequest: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			server := fiber.New()
			server.Use(New())
			server.Get("/", func(c fiber.Ctx) error {
				// In the "response" case the handler itself opts out.
				if !tc.onRequest {
					c.Set(fiber.HeaderCacheControl, "no-transform")
				}
				return c.SendString("hello")
			})

			request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
			request.Header.Set("Accept-Encoding", "gzip")
			if tc.onRequest {
				request.Header.Set(fiber.HeaderCacheControl, "no-transform")
			}

			response, err := server.Test(request, testConfig)
			require.NoError(t, err, "app.Test(req)")
			require.Empty(t, response.Header.Get(fiber.HeaderContentEncoding))
		})
	}
}
// Test_Compress_Next_Error checks that an error from the downstream handler is
// returned untouched by the middleware. The resulting 500 response carries the
// error text and no Content-Encoding.
func Test_Compress_Next_Error(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(_ fiber.Ctx) error {
		return errors.New("next error")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")
	// Use the shared testConfig like every other test in this file, so this one
	// is not the only test exposed to app.Test's default timeout.
	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode, "Status code")
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "next error", string(body))
}
// go test -run Test_Compress_Next
//
// Test_Compress_Next checks that when Config.Next returns true the middleware
// simply continues the chain. With no routes registered, the request ends in a 404.
func Test_Compress_Next(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Next: func(_ fiber.Ctx) bool {
			return true
		},
	}))
	// Use the shared testConfig for consistency with the rest of the file.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), testConfig)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -bench=Benchmark_Compress
//
// Benchmark_Compress measures the middleware for each supported encoding at the
// default level, driving app.Handler() directly with a reused RequestCtx.
//
// The response is reset on every iteration. The middleware skips any response
// that already has Content-Encoding, and nothing else here clears the header
// set by the first iteration. Without the reset, every later iteration would
// measure the skip path instead of compression.
func Benchmark_Compress(b *testing.B) {
	tests := []struct {
		name           string
		acceptEncoding string
	}{
		{name: "Gzip", acceptEncoding: "gzip"},
		{name: "Deflate", acceptEncoding: "deflate"},
		{name: "Brotli", acceptEncoding: "br"},
		{name: "Zstd", acceptEncoding: "zstd"},
	}
	for _, tt := range tests {
		b.Run(tt.name, func(b *testing.B) {
			app := fiber.New()
			app.Use(New())
			app.Get("/", func(c fiber.Ctx) error {
				c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
				return c.Send(filedata)
			})
			h := app.Handler()
			fctx := &fasthttp.RequestCtx{}
			fctx.Request.Header.SetMethod(fiber.MethodGet)
			fctx.Request.SetRequestURI("/")
			if tt.acceptEncoding != "" {
				fctx.Request.Header.Set("Accept-Encoding", tt.acceptEncoding)
			}
			b.ReportAllocs()
			for b.Loop() {
				// Drop headers/body left by the previous iteration (notably Content-Encoding).
				fctx.Response.Reset()
				h(fctx)
			}
		})
	}
}
// go test -bench=Benchmark_Compress_Levels
//
// Benchmark_Compress_Levels measures every encoding at every Level, including
// LevelDisabled as a pass-through baseline.
//
// The reused RequestCtx's response is reset each iteration. Otherwise the
// Content-Encoding header from the first run would make the middleware skip
// compression on all subsequent runs.
func Benchmark_Compress_Levels(b *testing.B) {
	tests := []struct {
		name           string
		acceptEncoding string
	}{
		{name: "Gzip", acceptEncoding: "gzip"},
		{name: "Deflate", acceptEncoding: "deflate"},
		{name: "Brotli", acceptEncoding: "br"},
		{name: "Zstd", acceptEncoding: "zstd"},
	}
	levels := []struct {
		name  string
		level Level
	}{
		{name: "LevelDisabled", level: LevelDisabled},
		{name: "LevelDefault", level: LevelDefault},
		{name: "LevelBestSpeed", level: LevelBestSpeed},
		{name: "LevelBestCompression", level: LevelBestCompression},
	}
	for _, tt := range tests {
		for _, lvl := range levels {
			b.Run(tt.name+"_"+lvl.name, func(b *testing.B) {
				app := fiber.New()
				app.Use(New(Config{Level: lvl.level}))
				app.Get("/", func(c fiber.Ctx) error {
					c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
					return c.Send(filedata)
				})
				h := app.Handler()
				fctx := &fasthttp.RequestCtx{}
				fctx.Request.Header.SetMethod(fiber.MethodGet)
				fctx.Request.SetRequestURI("/")
				if tt.acceptEncoding != "" {
					fctx.Request.Header.Set("Accept-Encoding", tt.acceptEncoding)
				}
				b.ReportAllocs()
				for b.Loop() {
					// Drop headers/body left by the previous iteration (notably Content-Encoding).
					fctx.Response.Reset()
					h(fctx)
				}
			})
		}
	}
}
// go test -bench=Benchmark_Compress_Parallel
//
// Benchmark_Compress_Parallel is the RunParallel variant of Benchmark_Compress.
// Each goroutine owns one RequestCtx and resets its response per iteration so
// that a stale Content-Encoding header cannot turn later runs into the
// middleware's skip path.
func Benchmark_Compress_Parallel(b *testing.B) {
	tests := []struct {
		name           string
		acceptEncoding string
	}{
		{name: "Gzip", acceptEncoding: "gzip"},
		{name: "Deflate", acceptEncoding: "deflate"},
		{name: "Brotli", acceptEncoding: "br"},
		{name: "Zstd", acceptEncoding: "zstd"},
	}
	for _, tt := range tests {
		b.Run(tt.name, func(b *testing.B) {
			app := fiber.New()
			app.Use(New())
			app.Get("/", func(c fiber.Ctx) error {
				c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
				return c.Send(filedata)
			})
			h := app.Handler()
			b.ReportAllocs()
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				fctx := &fasthttp.RequestCtx{}
				fctx.Request.Header.SetMethod(fiber.MethodGet)
				fctx.Request.SetRequestURI("/")
				if tt.acceptEncoding != "" {
					fctx.Request.Header.Set("Accept-Encoding", tt.acceptEncoding)
				}
				for pb.Next() {
					// Drop headers/body left by the previous iteration (notably Content-Encoding).
					fctx.Response.Reset()
					h(fctx)
				}
			})
		})
	}
}
// go test -bench=Benchmark_Compress_Levels_Parallel
//
// Benchmark_Compress_Levels_Parallel is the RunParallel variant of
// Benchmark_Compress_Levels. Each goroutine resets its own RequestCtx response
// per iteration so compression, not the "already encoded" skip, is what gets measured.
func Benchmark_Compress_Levels_Parallel(b *testing.B) {
	tests := []struct {
		name           string
		acceptEncoding string
	}{
		{name: "Gzip", acceptEncoding: "gzip"},
		{name: "Deflate", acceptEncoding: "deflate"},
		{name: "Brotli", acceptEncoding: "br"},
		{name: "Zstd", acceptEncoding: "zstd"},
	}
	levels := []struct {
		name  string
		level Level
	}{
		{name: "LevelDisabled", level: LevelDisabled},
		{name: "LevelDefault", level: LevelDefault},
		{name: "LevelBestSpeed", level: LevelBestSpeed},
		{name: "LevelBestCompression", level: LevelBestCompression},
	}
	for _, tt := range tests {
		for _, lvl := range levels {
			b.Run(tt.name+"_"+lvl.name, func(b *testing.B) {
				app := fiber.New()
				app.Use(New(Config{Level: lvl.level}))
				app.Get("/", func(c fiber.Ctx) error {
					c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
					return c.Send(filedata)
				})
				h := app.Handler()
				b.ReportAllocs()
				b.ResetTimer()
				b.RunParallel(func(pb *testing.PB) {
					fctx := &fasthttp.RequestCtx{}
					fctx.Request.Header.SetMethod(fiber.MethodGet)
					fctx.Request.SetRequestURI("/")
					if tt.acceptEncoding != "" {
						fctx.Request.Header.Set("Accept-Encoding", tt.acceptEncoding)
					}
					for pb.Next() {
						// Drop headers/body left by the previous iteration (notably Content-Encoding).
						fctx.Response.Reset()
						h(fctx)
					}
				})
			})
		}
	}
}
================================================
FILE: middleware/compress/config.go
================================================
package compress
import (
"github.com/gofiber/fiber/v3"
)
// Config holds the options accepted by the compress middleware.
type Config struct {
	// Next, when non-nil, is consulted for every request. Returning true
	// bypasses the middleware for that request.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// Level selects how aggressively responses are compressed.
	//
	// Optional. Default: LevelDefault
	//   LevelDisabled:        -1
	//   LevelDefault:          0
	//   LevelBestSpeed:        1
	//   LevelBestCompression:  2
	Level Level
}
// Level is the numeric representation of a compression level.
type Level int

// Compression levels accepted by Config.Level, numbered from -1 upward.
const (
	LevelDisabled        Level = iota - 1 // -1: compression turned off
	LevelDefault                          // 0
	LevelBestSpeed                        // 1
	LevelBestCompression                  // 2
)
// ConfigDefault is the configuration used when no Config is supplied.
var ConfigDefault = Config{
	Level: LevelDefault,
	Next:  nil,
}

// configDefault returns ConfigDefault when called without arguments.
// Otherwise it returns the first supplied Config, with an out-of-range Level
// replaced by ConfigDefault.Level.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}

	cfg := config[0]
	if valid := cfg.Level >= LevelDisabled && cfg.Level <= LevelBestCompression; !valid {
		cfg.Level = ConfigDefault.Level
	}
	return cfg
}
================================================
FILE: middleware/cors/config.go
================================================
package cors
import (
"github.com/gofiber/fiber/v3"
)
// Config defines the config for middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool
	// AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin'
	// response header to the 'origin' request header when returned true. This allows for
	// dynamic evaluation of allowed origins. Note: if AllowCredentials is true, wildcard origins
	// will not have the 'Access-Control-Allow-Credentials' header set to 'true'.
	//
	// The function receives serialized origins (scheme + host) or the literal "null" string.
	// According to the CORS specification (https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS#origin),
	// browsers send "null" for privacy-sensitive contexts like sandboxed iframes or file:// URLs.
	//
	// Origins with userinfo, paths, queries, fragments, or wildcards are rejected and will not
	// be passed to this function.
	//
	// Optional. Default: nil
	AllowOriginsFunc func(origin string) bool
	// AllowOrigins defines a list of origins that may access the resource.
	//
	// This supports wildcard matching for subdomains by prefixing the domain with `*.`,
	// e.g. "http://*.domain.com". This allows all levels of subdomains of domain.com to access the resource.
	//
	// If the special wildcard "*" is present in the list, all origins will be allowed.
	// An empty list combined with a nil AllowOriginsFunc also allows all origins.
	//
	// Optional. Default value []string{"*"}
	AllowOrigins []string
	// AllowMethods defines a list of methods allowed when accessing the resource.
	// This is used in response to a preflight request.
	//
	// Optional. Default value []string{"GET", "POST", "HEAD", "PUT", "DELETE", "PATCH"}
	AllowMethods []string
	// AllowHeaders defines a list of request headers that can be used when
	// making the actual request. This is in response to a preflight request.
	//
	// Optional. Default value []string{}
	AllowHeaders []string
	// ExposeHeaders defines an allowlist of headers that clients are allowed to
	// access.
	//
	// Optional. Default value []string{}.
	ExposeHeaders []string
	// MaxAge indicates how long (in seconds) the results of a preflight request
	// can be cached.
	// If you pass MaxAge 0, the Access-Control-Max-Age header will not be added and
	// the browser will use 5 seconds by default.
	// To disable caching completely, pass a negative MaxAge value. It will set the Access-Control-Max-Age header to 0.
	//
	// Optional. Default value 0.
	MaxAge int
	// DisableValueRedaction turns off redaction of configuration values and origins in logs and panics.
	//
	// Optional. Default: false
	DisableValueRedaction bool
	// AllowCredentials indicates whether or not the response to the request
	// can be exposed when the credentials flag is true. When used as part of
	// a response to a preflight request, this indicates whether or not the
	// actual request can be made using credentials. Note: if true, the
	// AllowOrigins setting cannot contain the wildcard "*" to prevent
	// security vulnerabilities.
	//
	// Optional. Default value false.
	AllowCredentials bool
	// AllowPrivateNetwork indicates whether the Access-Control-Allow-Private-Network
	// response header should be set to true, allowing requests from private networks.
	//
	// Optional. Default value false.
	AllowPrivateNetwork bool
}
// ConfigDefault is the configuration used when New is called without a Config:
// every origin is allowed for the common methods, without credentials.
// Fields are listed in struct declaration order.
var ConfigDefault = Config{
	Next:             nil,
	AllowOriginsFunc: nil,
	AllowOrigins:     []string{"*"},
	AllowMethods: []string{
		fiber.MethodGet,
		fiber.MethodPost,
		fiber.MethodHead,
		fiber.MethodPut,
		fiber.MethodDelete,
		fiber.MethodPatch,
	},
	AllowHeaders:          []string{},
	ExposeHeaders:         []string{},
	MaxAge:                0,
	DisableValueRedaction: false,
	AllowCredentials:      false,
	AllowPrivateNetwork:   false,
}
================================================
FILE: middleware/cors/cors.go
================================================
package cors
import (
"slices"
"strconv"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
utilsstrings "github.com/gofiber/utils/v2/strings"
)
// redactedValue replaces configuration values and origins in logs and panics
// while value redaction is enabled.
const redactedValue = "[redacted]"

// isOriginSerializedOrNull classifies a raw Origin header value. It reports
// isNull for the literal "null", in which case isSerialized is always false.
// Otherwise isSerialized is whatever validity normalizeOrigin reports for it.
func isOriginSerializedOrNull(originHeaderRaw string) (isSerialized, isNull bool) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming serialization and null status results
	isNull = originHeaderRaw == "null"
	if isNull {
		return false, true
	}
	isSerialized, _ = normalizeOrigin(originHeaderRaw)
	return isSerialized, false
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if len(cfg.AllowMethods) == 0 {
cfg.AllowMethods = ConfigDefault.AllowMethods
}
}
redactValues := !cfg.DisableValueRedaction
maskValue := func(value string) string {
if redactValues {
return redactedValue
}
return value
}
// Warning logs if both AllowOrigins and AllowOriginsFunc are set
if len(cfg.AllowOrigins) > 0 && cfg.AllowOriginsFunc != nil {
log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.")
}
// allowOrigins is a slice of strings that contains the allowed origins
// defined in the 'AllowOrigins' configuration.
allowOrigins := []string{}
allowSubOrigins := []subdomain{}
// Validate and normalize static AllowOrigins
allowAllOrigins := len(cfg.AllowOrigins) == 0 && cfg.AllowOriginsFunc == nil
for _, origin := range cfg.AllowOrigins {
if origin == "*" {
allowAllOrigins = true
break
}
trimmedOrigin := utils.TrimSpace(origin)
if before, after, found := strings.Cut(trimmedOrigin, "://*."); found {
withoutWildcard := before + "://" + after
isValid, normalizedOrigin := normalizeOrigin(withoutWildcard)
if !isValid {
panic("[CORS] Invalid origin format in configuration: " + maskValue(trimmedOrigin))
}
scheme, host, ok := strings.Cut(normalizedOrigin, "://")
if !ok {
panic("[CORS] Invalid origin format after normalization:" + maskValue(trimmedOrigin))
}
sd := subdomain{prefix: scheme + "://", suffix: host}
allowSubOrigins = append(allowSubOrigins, sd)
} else {
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CORS] Invalid origin format in configuration: " + maskValue(trimmedOrigin))
}
allowOrigins = append(allowOrigins, normalizedOrigin)
}
}
// Validate CORS credentials configuration
if cfg.AllowCredentials && allowAllOrigins {
panic("[CORS] Configuration error: When 'AllowCredentials' is set to true, 'AllowOrigins' cannot contain a wildcard origin '*'. Please specify allowed origins explicitly or adjust 'AllowCredentials' setting.")
}
// Warn if allowAllOrigins is set to true and AllowOriginsFunc is defined
if allowAllOrigins && cfg.AllowOriginsFunc != nil {
log.Warn("[CORS] 'AllowOrigins' is set to allow all origins, 'AllowOriginsFunc' will not be used.")
}
// Convert int to string
maxAge := strconv.Itoa(cfg.MaxAge)
// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get origin header preserving the original case for the response
originHeaderRaw := c.Get(fiber.HeaderOrigin)
originHeader := utilsstrings.ToLower(originHeaderRaw)
// If the request does not have Origin header, the request is outside the scope of CORS
if originHeader == "" {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
// Unless all origins are allowed, we include the Vary header to cache the response correctly
if !allowAllOrigins {
c.Vary(fiber.HeaderOrigin)
}
return c.Next()
}
// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// for non-CORS OPTIONS requests:
c.Vary(fiber.HeaderOrigin)
return c.Next()
}
// Set default allowOrigin to empty string
allowOrigin := ""
// Check allowed origins
if allowAllOrigins {
allowOrigin = "*"
} else {
// Check if the origin is in the list of allowed origins
if slices.Contains(allowOrigins, originHeader) {
allowOrigin = originHeaderRaw
}
// Check if the origin is in the list of allowed subdomains
if allowOrigin == "" {
for _, sOrigin := range allowSubOrigins {
if sOrigin.match(originHeader) {
allowOrigin = originHeaderRaw
break
}
}
}
}
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeaderRaw) {
originIsSerialized, originIsNull := isOriginSerializedOrNull(originHeaderRaw)
if originIsSerialized || originIsNull {
allowOrigin = originHeaderRaw
}
}
// Simple request
// Omit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
if !allowAllOrigins {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
c.Vary(fiber.HeaderOrigin)
}
setSimpleHeaders(c, allowOrigin, &cfg)
return c.Next()
}
// Pre-flight request
// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// of preflight responses:
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
if cfg.AllowPrivateNetwork && c.Get(fiber.HeaderAccessControlRequestPrivateNetwork) == "true" {
c.Vary(fiber.HeaderAccessControlRequestPrivateNetwork)
c.Set(fiber.HeaderAccessControlAllowPrivateNetwork, "true")
}
c.Vary(fiber.HeaderOrigin)
setPreflightHeaders(c, allowOrigin, maxAge, &cfg)
// Set Preflight headers
if len(cfg.AllowMethods) > 0 {
c.Set(fiber.HeaderAccessControlAllowMethods, strings.Join(cfg.AllowMethods, ", "))
}
if len(cfg.AllowHeaders) > 0 {
c.Set(fiber.HeaderAccessControlAllowHeaders, strings.Join(cfg.AllowHeaders, ", "))
} else {
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
if h != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
}
}
// Send 204 No Content
return c.SendStatus(fiber.StatusNoContent)
}
}
// Function to set Simple CORS headers
func setSimpleHeaders(c fiber.Ctx, allowOrigin string, cfg *Config) {
if cfg == nil {
return
}
if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin == "*" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
} else if allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
} else if allowOrigin != "" {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}
// Set Expose-Headers if not empty
if len(cfg.ExposeHeaders) > 0 {
c.Set(fiber.HeaderAccessControlExposeHeaders, strings.Join(cfg.ExposeHeaders, ", "))
}
}
// Function to set Preflight CORS headers
func setPreflightHeaders(c fiber.Ctx, allowOrigin, maxAge string, cfg *Config) {
setSimpleHeaders(c, allowOrigin, cfg)
// Set MaxAge if set
if cfg != nil && cfg.MaxAge > 0 {
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
} else if cfg != nil && cfg.MaxAge < 0 {
c.Set(fiber.HeaderAccessControlMaxAge, "0")
}
}
================================================
FILE: middleware/cors/cors_test.go
================================================
package cors
import (
"bytes"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func Test_CORS_Defaults(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
testDefaultOrEmptyConfig(t, app)
}
func Test_CORS_Empty_Config(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{}))
testDefaultOrEmptyConfig(t, app)
}
func Test_CORS_WildcardHeaders(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
AllowMethods: []string{"*"},
AllowHeaders: []string{"*"},
ExposeHeaders: []string{"*"},
}))
h := app.Handler()
// Test preflight request
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)))
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
func Test_CORS_Negative_MaxAge(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{MaxAge: -1}))
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
app.Handler()(ctx)
require.Equal(t, "0", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
}
func Test_CORS_MaxAge_NotSetOnSimpleRequest(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{MaxAge: 100}))
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
app.Handler()(ctx)
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
}
func Test_CORS_Preserve_Origin_Case(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{AllowOrigins: []string{"http://example.com"}}))
origin := "HTTP://EXAMPLE.COM"
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, origin)
app.Handler()(ctx)
require.Equal(t, origin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
t.Helper()
h := app.Handler()
// Test default GET response headers
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
// Test default OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
require.Equal(t, "GET, POST, HEAD, PUT, DELETE, PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)))
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
}
func Test_CORS_AllowOrigins_Vary(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(
Config{
AllowOrigins: []string{"http://localhost"},
},
))
h := app.Handler()
// Test Vary header non-Cors request
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set")
// Test Vary header Cors request
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set")
}
// go test -run -v Test_CORS_Wildcard
func Test_CORS_Wildcard(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
// OPTIONS (preflight) response headers when AllowOrigins is *
app.Use(New(Config{
MaxAge: 3600,
ExposeHeaders: []string{"X-Request-ID"},
AllowHeaders: []string{"Authentication"},
}))
// Get handler pointer
handler := app.Handler()
// Make request
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
// Perform request
handler(ctx)
// Check result
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response
require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set")
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Equal(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
require.Equal(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
handler(ctx)
require.NotContains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should not be set")
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Equal(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
// go test -run -v Test_CORS_Origin_AllowCredentials
func Test_CORS_Origin_AllowCredentials(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
// OPTIONS (preflight) response headers when AllowOrigins is *
app.Use(New(Config{
AllowOrigins: []string{"http://localhost"},
AllowCredentials: true,
MaxAge: 3600,
ExposeHeaders: []string{"X-Request-ID"},
AllowHeaders: []string{"Authentication"},
}))
// Get handler pointer
handler := app.Handler()
// Make request
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
// Perform request
handler(ctx)
// Check result
require.Equal(t, "http://localhost", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
require.Equal(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Equal(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
require.Equal(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)
require.Equal(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Equal(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
// go test -run -v Test_CORS_Wildcard_AllowCredentials_Panic
// Test for fiber-ghsa-fmg4-x8pw-hjhg
func Test_CORS_Wildcard_AllowCredentials_Panic(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
didPanic := false
func() {
defer func() {
if r := recover(); r != nil {
didPanic = true
}
}()
app.Use(New(Config{
AllowOrigins: []string{"*"},
AllowCredentials: true,
}))
}()
if !didPanic {
t.Error("Expected a panic when AllowOrigins is '*' and AllowCredentials is true")
}
}
// Test that a warning is logged when AllowOrigins allows all origins and
// AllowOriginsFunc is also provided.
func Test_CORS_Warn_AllowAllOrigins_WithFunc(t *testing.T) {
var buf bytes.Buffer
log.SetOutput(&buf)
t.Cleanup(func() { log.SetOutput(os.Stderr) })
fiber.New().Use(New(Config{
AllowOrigins: []string{"*"},
AllowOriginsFunc: func(string) bool { return true },
}))
require.Contains(t, buf.String(), "AllowOriginsFunc' will not be used")
}
// go test -run -v Test_CORS_Invalid_Origin_Panic
func Test_CORS_Invalid_Origins_Panic(t *testing.T) {
t.Parallel()
invalidOrigins := []string{
"localhost",
"http://foo.[a-z]*.example.com",
"http://*",
"https://*",
"http://*.com*",
"invalid url",
"*",
"http://origin.com,invalid url",
// add more invalid origins as needed
}
for _, origin := range invalidOrigins {
// New fiber instance
app := fiber.New()
didPanic := false
func() {
defer func() {
if r := recover(); r != nil {
didPanic = true
}
}()
app.Use(New(Config{
AllowOrigins: []string{origin},
AllowCredentials: true,
}))
}()
if !didPanic {
t.Errorf("Expected a panic for invalid origin: %s", origin)
}
}
}
func Test_CORS_DisableValueRedaction(t *testing.T) {
t.Parallel()
require.PanicsWithValue(t, "[CORS] Invalid origin format in configuration: [redacted]", func() {
New(Config{
AllowOrigins: []string{"http://"},
DisableValueRedaction: false,
})
})
require.PanicsWithValue(t, "[CORS] Invalid origin format in configuration: http://", func() {
New(Config{
AllowOrigins: []string{"http://"},
DisableValueRedaction: true,
})
})
}
// go test -run -v Test_CORS_Subdomain
func Test_CORS_Subdomain(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
// OPTIONS (preflight) response headers when AllowOrigins is set to a subdomain
app.Use("/", New(Config{
AllowOrigins: []string{" http://*.example.com "},
}))
// Get handler pointer
handler := app.Handler()
// Make request with disallowed origin
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
handler(ctx)
// Allow-Origin header should be "" because http://google.com does not satisfy http://*.example.com
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with domain only (disallowed)
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
handler(ctx)
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with malformed subdomain (disallowed)
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://evil.comexample.com")
handler(ctx)
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com")
handler(ctx)
require.Equal(t, "http://test.example.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
func Test_CORS_AllowOriginScheme(t *testing.T) {
t.Parallel()
tests := []struct {
reqOrigin string
pattern []string
shouldAllowOrigin bool
}{
{
pattern: []string{"http://example.com"},
reqOrigin: "http://example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"HTTP://EXAMPLE.COM"},
reqOrigin: "http://example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"https://example.com"},
reqOrigin: "https://example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://example.com"},
reqOrigin: "https://example.com",
shouldAllowOrigin: false,
},
{
pattern: []string{"http://*.example.com"},
reqOrigin: "http://aaa.example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://*.example.com"},
reqOrigin: "http://bbb.aaa.example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://*.aaa.example.com"},
reqOrigin: "http://bbb.aaa.example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://*.example.com:8080"},
reqOrigin: "http://aaa.example.com:8080",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://*.example.com"},
reqOrigin: "http://1.2.aaa.example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://example.com"},
reqOrigin: "http://gofiber.com",
shouldAllowOrigin: false,
},
{
pattern: []string{"http://*.aaa.example.com"},
reqOrigin: "http://ccc.bbb.example.com",
shouldAllowOrigin: false,
},
{
pattern: []string{"http://*.example.com"},
reqOrigin: "http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://example.com"},
reqOrigin: "http://ccc.bbb.example.com",
shouldAllowOrigin: false,
},
{
pattern: []string{"https://--aaa.bbb.com"},
reqOrigin: "https://prod-preview--aaa.bbb.com",
shouldAllowOrigin: false,
},
{
pattern: []string{"http://*.example.com"},
reqOrigin: "http://ccc.bbb.example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://domain-1.com", "http://example.com"},
reqOrigin: "http://example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://domain-1.com", "http://example.com"},
reqOrigin: "http://domain-2.com",
shouldAllowOrigin: false,
},
{
pattern: []string{"http://domain-1.com", "http://example.com"},
reqOrigin: "http://example.com",
shouldAllowOrigin: true,
},
{
pattern: []string{"http://domain-1.com", "http://example.com"},
reqOrigin: "http://domain-2.com",
shouldAllowOrigin: false,
},
{
pattern: []string{"http://domain-1.com", "http://example.com"},
reqOrigin: "http://domain-1.com",
shouldAllowOrigin: true,
},
}
for _, tt := range tests {
app := fiber.New()
app.Use("/", New(Config{AllowOrigins: tt.pattern}))
handler := app.Handler()
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin)
handler(ctx)
if tt.shouldAllowOrigin {
require.Equal(t, tt.reqOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
} else {
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
}
}
func Test_CORS_AllowOriginHeader_NoMatch(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
app.Use("/", New(Config{
AllowOrigins: []string{"http://example-1.com", "https://example-1.com"},
}))
// Get handler pointer
handler := app.Handler()
// Make request with disallowed origin
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
handler(ctx)
var headerExists bool
for key := range ctx.Response.Header.All() {
if string(key) == fiber.HeaderAccessControlAllowOrigin {
headerExists = true
}
}
require.False(t, headerExists, "Access-Control-Allow-Origin header should not be set")
}
// go test -run Test_CORS_Next
func Test_CORS_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_CORS_Headers_BasedOnRequestType
func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
methods := []string{
fiber.MethodGet,
fiber.MethodPost,
fiber.MethodPut,
fiber.MethodDelete,
fiber.MethodPatch,
fiber.MethodHead,
}
// Get handler pointer
handler := app.Handler()
t.Run("Without origin", func(t *testing.T) {
t.Parallel()
// Make request without origin header, and without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})
t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
handler(ctx)
require.Equal(t, 204, ctx.Response.StatusCode(), "Status code should be 204")
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
require.Equal(t, "GET, POST, HEAD, PUT, DELETE, PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)")
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should be set (preflight request)")
}
})
t.Run("Non-preflight request with origin", func(t *testing.T) {
t.Parallel()
// Make non-preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/api/action")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should not be set (non-preflight request)")
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should not be set (non-preflight request)")
}
})
t.Run("Preflight with Access-Control-Request-Headers", func(t *testing.T) {
t.Parallel()
// Make preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestHeaders, "X-Custom-Header")
handler(ctx)
require.Equal(t, 204, ctx.Response.StatusCode(), "Status code should be 204")
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
require.Equal(t, "GET, POST, HEAD, PUT, DELETE, PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)")
require.Equal(t, "X-Custom-Header", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should be set (preflight request)")
}
})
}
func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
app.Use("/", New(Config{
AllowOrigins: []string{"http://example-1.com"},
AllowOriginsFunc: func(origin string) bool {
return strings.Contains(origin, "example-2")
},
}))
// Get handler pointer
handler := app.Handler()
// Make request with disallowed origin
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
handler(ctx)
// Allow-Origin header should be "" because http://google.com does not satisfy http://example-1.com or 'strings.Contains(origin, "example-2")'
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com")
handler(ctx)
require.Equal(t, "http://example-1.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
handler(ctx)
require.Equal(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
func Test_CORS_AllowOriginsFunc(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
app.Use("/", New(Config{
AllowOriginsFunc: func(origin string) bool {
return strings.Contains(origin, "example-2")
},
}))
// Get handler pointer
handler := app.Handler()
// Make request with disallowed origin
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
handler(ctx)
// Allow-Origin header should be empty because http://google.com does not satisfy 'strings.Contains(origin, "example-2")'
// and AllowOrigins has not been set
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")
handler(ctx)
// Allow-Origin header should be "http://example-2.com"
require.Equal(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
func Test_CORS_AllowOriginsFuncRejectsNonSerializedOrigins(t *testing.T) {
t.Parallel()
testCases := []struct {
Name string
Origin string
Method string
SetPreflightMethod bool
ExpectAllowed bool
}{
{
Name: "UserInfoPresent",
Origin: "http://user:pass@example.com",
Method: fiber.MethodGet,
},
{
Name: "PathPresent",
Origin: "http://example.com/path",
Method: fiber.MethodOptions,
SetPreflightMethod: true,
},
{
Name: "QueryPresent",
Origin: "http://example.com?query=1",
Method: fiber.MethodOptions,
SetPreflightMethod: true,
},
{
Name: "FragmentPresent",
Origin: "http://example.com#section",
Method: fiber.MethodGet,
},
{
Name: "WildcardHost",
Origin: "http://*.example.com",
Method: fiber.MethodGet,
},
{
Name: "StandaloneWildcard",
Origin: "*",
Method: fiber.MethodGet,
},
{
Name: "NullOriginUppercase",
Origin: "NULL",
Method: fiber.MethodGet,
},
{
Name: "NullOriginMixedCase",
Origin: "Null",
Method: fiber.MethodGet,
},
{
Name: "NullOriginLowercase",
Origin: "null",
Method: fiber.MethodGet,
ExpectAllowed: true,
},
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use("/", New(Config{
AllowOriginsFunc: func(string) bool { return true },
}))
app.All("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
handler := app.Handler()
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(tc.Method)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.Origin)
if tc.SetPreflightMethod {
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
}
handler(ctx)
if tc.ExpectAllowed {
require.Equal(t, "null", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
} else {
require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
if tc.Method == fiber.MethodOptions {
require.Equal(t, fiber.StatusNoContent, ctx.Response.StatusCode())
} else {
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
}
})
}
}
// Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases sends a preflight request for
// every combination of AllowOrigins (defined/empty) and AllowOriginsFunc
// (undefined/true/false) and checks the resulting Access-Control-Allow-Origin header.
// An origin listed in AllowOrigins is allowed regardless of AllowOriginsFunc; an
// unlisted origin is allowed only when AllowOriginsFunc approves it.
func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		Name           string
		RequestOrigin  string
		ResponseOrigin string
		Config         Config
	}{
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginAllowed",
			Config: Config{
				AllowOrigins:     []string{"http://aaa.com"},
				AllowOriginsFunc: nil,
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "http://aaa.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/OriginAllowed",
			Config: Config{
				AllowOrigins:     []string{"http://aaa.com", "http://bbb.com"},
				AllowOriginsFunc: nil,
			},
			RequestOrigin:  "http://bbb.com",
			ResponseOrigin: "http://bbb.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/OriginNotAllowed",
			Config: Config{
				AllowOrigins:     []string{"http://aaa.com", "http://bbb.com"},
				AllowOriginsFunc: nil,
			},
			RequestOrigin:  "http://ccc.com",
			ResponseOrigin: "",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/Whitespace/OriginAllowed",
			Config: Config{
				AllowOrigins:     []string{" http://aaa.com ", " http://bbb.com "},
				AllowOriginsFunc: nil,
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "http://aaa.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed",
			Config: Config{
				AllowOrigins:     []string{"http://aaa.com"},
				AllowOriginsFunc: nil,
			},
			RequestOrigin:  "http://bbb.com",
			ResponseOrigin: "",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginAllowed",
			Config: Config{
				AllowOrigins: []string{"http://aaa.com"},
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "http://aaa.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginNotAllowed",
			Config: Config{
				AllowOrigins: []string{"http://aaa.com"},
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin:  "http://bbb.com",
			ResponseOrigin: "http://bbb.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginAllowed",
			Config: Config{
				AllowOrigins: []string{"http://aaa.com"},
				AllowOriginsFunc: func(_ string) bool {
					return false
				},
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "http://aaa.com",
		},
		{
			Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
			Config: Config{
				AllowOrigins: []string{"http://aaa.com"},
				AllowOriginsFunc: func(_ string) bool {
					return false
				},
			},
			RequestOrigin:  "http://bbb.com",
			ResponseOrigin: "",
		},
		{
			Name: "AllowOriginsEmpty/AllowOriginsFuncUndefined/OriginAllowed",
			Config: Config{
				AllowOriginsFunc: nil,
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "*",
		},
		{
			Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsTrue/OriginAllowed",
			Config: Config{
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "http://aaa.com",
		},
		{
			Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsFalse/OriginNotAllowed",
			Config: Config{
				AllowOriginsFunc: func(_ string) bool {
					return false
				},
			},
			RequestOrigin:  "http://aaa.com",
			ResponseOrigin: "",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			// Each subtest builds its own app and context, so they can run concurrently.
			t.Parallel()
			app := fiber.New()
			app.Use("/", New(tc.Config))
			handler := app.Handler()
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.SetRequestURI("/")
			ctx.Request.Header.SetMethod(fiber.MethodOptions)
			ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
			ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
			handler(ctx)
			require.Equal(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
		})
	}
}
// The fix for issue #2422
//
// Test_CORS_AllowCredentials checks, via preflight requests, which
// Access-Control-Allow-Origin and Access-Control-Allow-Credentials headers are
// produced for combinations of AllowCredentials, AllowOrigins and AllowOriginsFunc.
func Test_CORS_AllowCredentials(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		Name                string
		RequestOrigin       string
		ResponseOrigin      string
		ResponseCredentials string
		Config              Config
	}{
		{
			Name: "AllowOriginsFuncDefined",
			Config: Config{
				AllowCredentials: true,
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin: "http://aaa.com",
			// The AllowOriginsFunc config was defined, should use the real origin of the function
			ResponseOrigin:      "http://aaa.com",
			ResponseCredentials: "true",
		},
		{
			Name: "fiber-ghsa-fmg4-x8pw-hjhg-wildcard-credentials",
			Config: Config{
				AllowCredentials: true,
				AllowOriginsFunc: func(_ string) bool {
					return true
				},
			},
			RequestOrigin:  "*",
			ResponseOrigin: "",
			// Middleware will validate that wildcard won't set credentials to true and reject non-serialized origins
			ResponseCredentials: "",
		},
		{
			Name: "AllowOriginsFuncNotDefined",
			Config: Config{
				// Setting this to true will cause the middleware to panic since default AllowOrigins is "*"
				AllowCredentials: false,
			},
			RequestOrigin: "http://aaa.com",
			// None of the AllowOrigins or AllowOriginsFunc config was defined, should use the default origin of "*"
			// which will cause the CORS error in the client:
			// The value of the 'Access-Control-Allow-Origin' header in the response must not be the wildcard '*'
			// when the request's credentials mode is 'include'.
			ResponseOrigin:      "*",
			ResponseCredentials: "",
		},
		{
			Name: "AllowOriginsDefined",
			Config: Config{
				AllowCredentials: true,
				AllowOrigins:     []string{"http://aaa.com"},
			},
			RequestOrigin:       "http://aaa.com",
			ResponseOrigin:      "http://aaa.com",
			ResponseCredentials: "true",
		},
		{
			Name: "AllowOriginsDefined/UnallowedOrigin",
			Config: Config{
				AllowCredentials: true,
				AllowOrigins:     []string{"http://aaa.com"},
			},
			RequestOrigin:       "http://bbb.com",
			ResponseOrigin:      "",
			ResponseCredentials: "",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			// Each subtest builds its own app and context, so they can run concurrently.
			t.Parallel()
			app := fiber.New()
			app.Use("/", New(tc.Config))
			handler := app.Handler()
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.SetRequestURI("/")
			ctx.Request.Header.SetMethod(fiber.MethodOptions)
			ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
			ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)
			handler(ctx)
			require.Equal(t, tc.ResponseCredentials, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
			require.Equal(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
		})
	}
}
// The Enhancement for issue #2804
//
// Test_CORS_AllowPrivateNetwork checks that Access-Control-Allow-Private-Network is
// only sent for a preflight request that asks for private network access
// ("Access-Control-Request-Private-Network: true") when AllowPrivateNetwork is enabled.
func Test_CORS_AllowPrivateNetwork(t *testing.T) {
	t.Parallel()

	const (
		requestPrivateNetwork = "Access-Control-Request-Private-Network"
		allowPrivateNetwork   = "Access-Control-Allow-Private-Network"
	)

	// send dispatches one request from https://example.com through a fresh app using
	// New(cfg...) and returns the Access-Control-Allow-Private-Network response header.
	// preflight adds Access-Control-Request-Method; a non-empty privateNetwork value is
	// sent as the Access-Control-Request-Private-Network request header.
	send := func(method string, preflight bool, privateNetwork string, cfg ...Config) string {
		app := fiber.New()
		app.Use(New(cfg...))
		ctx := &fasthttp.RequestCtx{}
		ctx.Request.Header.SetMethod(method)
		ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com")
		if preflight {
			ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
		}
		if privateNetwork != "" {
			ctx.Request.Header.Set(requestPrivateNetwork, privateNetwork)
		}
		app.Handler()(ctx)
		return string(ctx.Response.Header.Peek(allowPrivateNetwork))
	}
	enabled := Config{AllowPrivateNetwork: true}

	// Enabled, preflight, client requests private network access: header granted.
	require.Equal(t, "true", send(fiber.MethodOptions, true, "true", enabled),
		"The Access-Control-Allow-Private-Network header should be set to 'true' when AllowPrivateNetwork is enabled")

	// Enabled, but the requests are not preflights: no header.
	require.Empty(t, send(fiber.MethodGet, false, "true", enabled),
		"The Access-Control-Allow-Private-Network header should not be set for a non-preflight GET request")
	require.Empty(t, send(fiber.MethodOptions, false, "true", enabled),
		"The Access-Control-Allow-Private-Network header should not be set for a non-preflight OPTIONS request")

	// Disabled (default config), with and without the client header: no header.
	require.Empty(t, send(fiber.MethodOptions, true, ""),
		"The Access-Control-Allow-Private-Network header should not be present by default")
	require.Empty(t, send(fiber.MethodOptions, true, "true"),
		"The Access-Control-Allow-Private-Network header should not be present when AllowPrivateNetwork is disabled")

	// Enabled, preflight, but the client does not ask for (or declines) private network access.
	require.Empty(t, send(fiber.MethodOptions, true, "", enabled),
		"The Access-Control-Allow-Private-Network header should not be present when the client does not request private network access")
	require.Empty(t, send(fiber.MethodOptions, true, "false", enabled),
		"The Access-Control-Allow-Private-Network header should not be present when the client sends Access-Control-Request-Private-Network: false")
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandler -benchmem -count=4
func Benchmark_CORS_NewHandler(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: []string{"http://localhost", "http://example.com"},
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
for b.Loop() {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandler_Parallel -benchmem -count=4
func Benchmark_CORS_NewHandler_Parallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: []string{"http://localhost", "http://example.com"},
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOrigin -benchmem -count=4
func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: []string{"http://example.com"},
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
for b.Loop() {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOrigin_Parallel -benchmem -count=4
func Benchmark_CORS_NewHandlerSingleOrigin_Parallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: []string{"http://example.com"},
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcard -benchmem -count=4
func Benchmark_CORS_NewHandlerWildcard(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: false,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
for b.Loop() {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcard_Parallel -benchmem -count=4
func Benchmark_CORS_NewHandlerWildcard_Parallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: false,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflight -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflight(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: []string{"http://localhost", "http://example.com"},
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Preflight request
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
for b.Loop() {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflight_Parallel -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflight_Parallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: []string{"http://localhost", "http://example.com"},
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOrigin -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflightSingleOrigin(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: []string{"http://example.com"},
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
for b.Loop() {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOrigin_Parallel -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflightSingleOrigin_Parallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: []string{"http://example.com"},
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcard -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflightWildcard(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: false,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
for b.Loop() {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcard_Parallel -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflightWildcard_Parallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete},
AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept},
AllowCredentials: false,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
================================================
FILE: middleware/cors/utils.go
================================================
package cors
import (
"net/url"
"strings"
utilsstrings "github.com/gofiber/utils/v2/strings"
)
// matchScheme reports whether domain and pattern both contain a ':' and the text
// before the first ':' (the scheme) is identical in both.
func matchScheme(domain, pattern string) bool {
	domainEnd := strings.IndexByte(domain, ':')
	if domainEnd < 0 {
		return false
	}
	patternEnd := strings.IndexByte(pattern, ':')
	if patternEnd < 0 {
		return false
	}
	return domain[:domainEnd] == pattern[:patternEnd]
}
// normalizeDomain removes an "http://" or "https://" scheme and any port from the
// input domain. Other schemes are left untouched. Bracketed IPv6 literals keep their
// brackets, so "http://[::1]:8080" becomes "[::1]".
func normalizeDomain(input string) string {
	// Remove scheme
	if after, found := strings.CutPrefix(input, "https://"); found {
		input = after
	} else if after, found := strings.CutPrefix(input, "http://"); found {
		input = after
	}

	if input == "" {
		return input
	}

	// Bracketed IPv6 literal: the colons inside the brackets belong to the address,
	// so only a ':' directly after the closing bracket introduces a port.
	if input[0] == '[' {
		if end := strings.IndexByte(input, ']'); end != -1 && end+1 < len(input) && input[end+1] == ':' {
			return input[:end+1]
		}
		return input
	}

	// Find and remove port, if present
	if before, _, found := strings.Cut(input, ":"); found {
		input = before
	}
	return input
}
// normalizeOrigin checks if the provided origin is in a correct format
// and normalizes it by removing a trailing slash.
// It returns a boolean indicating whether the origin is valid
// and the normalized origin.
//
// A valid origin consists of a scheme and a non-empty host (optionally with a
// port) and nothing else. Userinfo, a path other than "/", a query, or a fragment
// make it invalid, and so does an empty "?" or "#" marker. Wildcards in the host
// are rejected. The normalized form is the lowercased scheme and host joined by "://".
func normalizeOrigin(origin string) (valid bool, normalized string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming validity and normalized origin results
	parsedOrigin, err := url.Parse(origin)
	if err != nil {
		return false, ""
	}

	// Don't allow a wildcard with a protocol
	// wildcards cannot be used within any other value. For example, the following header is not valid:
	// Access-Control-Allow-Origin: https://*
	if strings.IndexByte(parsedOrigin.Host, '*') >= 0 {
		return false, ""
	}

	// Validate that a scheme and a host are present. The presence of a path, query, or
	// fragment component is checked, but a trailing "/" (indicative of the root) is
	// allowed for the path and will be normalized.
	// url.Parse accepts scheme-relative input ("//example.com"), so the scheme is
	// checked explicitly. It also silently drops an empty query or fragment. A trailing
	// "?" is only visible through ForceQuery, and any '#' in the raw input means a
	// fragment delimiter was present.
	if parsedOrigin.Scheme == "" ||
		parsedOrigin.User != nil ||
		parsedOrigin.Host == "" ||
		(parsedOrigin.Path != "" && parsedOrigin.Path != "/") ||
		parsedOrigin.RawQuery != "" ||
		parsedOrigin.ForceQuery ||
		strings.IndexByte(origin, '#') >= 0 {
		return false, ""
	}

	// Normalize the origin by constructing it from the scheme and host.
	// The trailing slash is not included in the normalized origin.
	return true, utilsstrings.ToLower(parsedOrigin.Scheme) + "://" + utilsstrings.ToLower(parsedOrigin.Host)
}
// subdomain describes an allowed-origin wildcard pattern split around its wildcard
// label. prefix is the text before the wildcard (presumably the scheme, e.g.
// "https://" — confirm against the constructor). suffix is the domain after it,
// without the separating dot.
type subdomain struct {
	// The wildcard pattern
	prefix string
	suffix string
}

// match reports whether o consists of prefix, one or more non-empty dot-separated
// labels, a '.', and suffix. Origins with empty labels, such as
// "https://.example.com", "https://..example.com" or "https://a..example.com", do not match.
func (s subdomain) match(o string) bool {
	// Not a subdomain if not long enough for a dot separator.
	if len(o) < len(s.prefix)+len(s.suffix)+1 {
		return false
	}
	if !strings.HasPrefix(o, s.prefix) || !strings.HasSuffix(o, s.suffix) {
		return false
	}

	// The character right before the suffix must be the separator dot.
	suffixStartIndex := len(o) - len(s.suffix)
	if suffixStartIndex <= len(s.prefix) || o[suffixStartIndex-1] != '.' {
		return false
	}

	// The subdomain part (between prefix and the separator dot) must consist of
	// non-empty labels. A trailing dot would form ".." with the separator (the
	// "a..example.com" case), so the leading dot, the trailing dot and any interior
	// ".." are all rejected.
	sub := o[len(s.prefix) : suffixStartIndex-1]
	if sub == "" || sub[0] == '.' || sub[len(sub)-1] == '.' || strings.Contains(sub, "..") {
		return false
	}
	return true
}
================================================
FILE: middleware/cors/utils_test.go
================================================
package cors
import (
"testing"
"github.com/stretchr/testify/assert"
)
// go test -v -run Test_NormalizeOrigin
//
// Test_NormalizeOrigin checks which origin strings normalizeOrigin accepts and the
// normalized form it returns for accepted ones.
func Test_NormalizeOrigin(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		origin         string
		expectedOrigin string
		expectedValid  bool
	}{
		{origin: "http://example.com", expectedValid: true, expectedOrigin: "http://example.com"},                                 // Simple case should work.
		{origin: "http://example.com/", expectedValid: true, expectedOrigin: "http://example.com"},                                // Trailing slash should be removed.
		{origin: "HTTP://EXAMPLE.COM", expectedValid: true, expectedOrigin: "http://example.com"},                                 // Scheme and host should be lowercased.
		{origin: "http://example.com:3000", expectedValid: true, expectedOrigin: "http://example.com:3000"},                       // Port should be preserved.
		{origin: "http://example.com:3000/", expectedValid: true, expectedOrigin: "http://example.com:3000"},                      // Trailing slash should be removed.
		{origin: "app://example.com/", expectedValid: true, expectedOrigin: "app://example.com"},                                  // App scheme should be accepted.
		{origin: "http://", expectedValid: false, expectedOrigin: ""},                                                             // Invalid origin should not be accepted.
		{origin: "file:///etc/passwd", expectedValid: false, expectedOrigin: ""},                                                  // File scheme should not be accepted.
		{origin: "https://*example.com", expectedValid: false, expectedOrigin: ""},                                                // Wildcard domain should not be accepted.
		{origin: "http://*.example.com", expectedValid: false, expectedOrigin: ""},                                                // Wildcard subdomain should not be accepted.
		{origin: "http://example.com/path", expectedValid: false, expectedOrigin: ""},                                             // Path should not be accepted.
		{origin: "http://example.com?query=123", expectedValid: false, expectedOrigin: ""},                                        // Query should not be accepted.
		{origin: "http://example.com#fragment", expectedValid: false, expectedOrigin: ""},                                         // Fragment should not be accepted.
		{origin: "http://user:pass@example.com", expectedValid: false, expectedOrigin: ""},                                        // Userinfo should not be accepted.
		{origin: "http://localhost", expectedValid: true, expectedOrigin: "http://localhost"},                                     // Localhost should be accepted.
		{origin: "http://127.0.0.1", expectedValid: true, expectedOrigin: "http://127.0.0.1"},                                     // IPv4 address should be accepted.
		{origin: "http://[::1]", expectedValid: true, expectedOrigin: "http://[::1]"},                                             // IPv6 address should be accepted.
		{origin: "http://[::1]:8080", expectedValid: true, expectedOrigin: "http://[::1]:8080"},                                   // IPv6 address with port should be accepted.
		{origin: "http://[::1]:8080/", expectedValid: true, expectedOrigin: "http://[::1]:8080"},                                  // IPv6 address with port and trailing slash should be accepted.
		{origin: "http://[::1]:8080/path", expectedValid: false, expectedOrigin: ""},                                              // IPv6 address with port and path should not be accepted.
		{origin: "http://[::1]:8080?query=123", expectedValid: false, expectedOrigin: ""},                                         // IPv6 address with port and query should not be accepted.
		{origin: "http://[::1]:8080#fragment", expectedValid: false, expectedOrigin: ""},                                          // IPv6 address with port and fragment should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment", expectedValid: false, expectedOrigin: ""},                           // IPv6 address with port, path, query, and fragment should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/", expectedValid: false, expectedOrigin: ""},                          // IPv6 address with port, path, query, fragment, and trailing slash should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/invalid", expectedValid: false, expectedOrigin: ""},                   // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/invalid/", expectedValid: false, expectedOrigin: ""},                  // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with trailing slash should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/invalid/segment", expectedValid: false, expectedOrigin: ""},           // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with additional segment should not be accepted.
	}
	for _, tc := range testCases {
		valid, normalizedOrigin := normalizeOrigin(tc.origin)
		if valid != tc.expectedValid {
			t.Errorf("Expected origin %q to be valid: %v, but got: %v", tc.origin, tc.expectedValid, valid)
		}
		if normalizedOrigin != tc.expectedOrigin {
			t.Errorf("Expected normalized origin %q for origin %q, but got: %q", tc.expectedOrigin, tc.origin, normalizedOrigin)
		}
	}
}
// go test -v -run Test_MatchScheme
//
// Test_MatchScheme checks that matchScheme compares only the scheme part and ignores
// host and port differences.
func Test_MatchScheme(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		domain   string
		pattern  string
		expected bool
	}{
		{domain: "http://example.com", pattern: "http://example.com", expected: true},       // Exact match should work.
		{domain: "https://example.com", pattern: "http://example.com", expected: false},     // Scheme mismatch should matter.
		{domain: "http://example.com", pattern: "https://example.com", expected: false},     // Scheme mismatch should matter.
		{domain: "http://example.com", pattern: "http://example.org", expected: true},       // Different domains should not matter.
		{domain: "http://example.com", pattern: "http://example.com:8080", expected: true},  // Port should not matter.
		{domain: "http://example.com:8080", pattern: "http://example.com", expected: true},  // Port should not matter.
		{domain: "http://example.com:8080", pattern: "http://example.com:8081", expected: true}, // Different ports should not matter.
		{domain: "http://localhost", pattern: "http://localhost", expected: true},           // Localhost should match.
		{domain: "http://127.0.0.1", pattern: "http://127.0.0.1", expected: true},           // IPv4 address should match.
		{domain: "http://[::1]", pattern: "http://[::1]", expected: true},                   // IPv6 address should match.
		{domain: "example.com", pattern: "example.com", expected: false},                    // Inputs without a scheme should not match.
	}
	for _, tc := range testCases {
		result := matchScheme(tc.domain, tc.pattern)
		if result != tc.expected {
			t.Errorf("Expected matchScheme(%q, %q) to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result)
		}
	}
}
// go test -v -run Test_NormalizeDomain
//
// Test_NormalizeDomain checks that normalizeDomain strips an http/https scheme and a
// port while leaving paths, queries and fragments in place.
func Test_NormalizeDomain(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		input          string
		expectedOutput string
	}{
		{input: "http://example.com", expectedOutput: "example.com"},                     // Simple case with http scheme.
		{input: "https://example.com", expectedOutput: "example.com"},                    // Simple case with https scheme.
		{input: "http://example.com:3000", expectedOutput: "example.com"},                // Case with port.
		{input: "https://example.com:3000", expectedOutput: "example.com"},               // Case with port and https scheme.
		{input: "http://example.com/path", expectedOutput: "example.com/path"},           // Case with path.
		{input: "http://example.com?query=123", expectedOutput: "example.com?query=123"}, // Case with query.
		{input: "http://example.com#fragment", expectedOutput: "example.com#fragment"},   // Case with fragment.
		{input: "example.com", expectedOutput: "example.com"},                            // Case without scheme.
		{input: "example.com:8080", expectedOutput: "example.com"},                       // Case without scheme but with port.
		{input: "sub.example.com", expectedOutput: "sub.example.com"},                    // Case with subdomain.
		{input: "sub.sub.example.com", expectedOutput: "sub.sub.example.com"},            // Case with nested subdomain.
		{input: "http://localhost", expectedOutput: "localhost"},                         // Case with localhost.
		{input: "http://127.0.0.1", expectedOutput: "127.0.0.1"},                         // Case with IPv4 address.
		{input: "http://[::1]", expectedOutput: "[::1]"},                                 // Case with IPv6 address.
		{input: "https://[::1]", expectedOutput: "[::1]"},                                // Case with IPv6 address and https scheme.
		{input: "", expectedOutput: ""},                                                  // Empty input stays empty.
	}
	for _, tc := range testCases {
		output := normalizeDomain(tc.input)
		if output != tc.expectedOutput {
			t.Errorf("Expected normalized domain %q for input %q, but got: %q", tc.expectedOutput, tc.input, output)
		}
	}
}
// Benchmark_CORS_SubdomainMatch measures subdomain.match for an origin that
// satisfies both the "www" prefix and the "example.com" suffix.
//
// go test -v -run=^$ -bench=Benchmark_CORS_SubdomainMatch -benchmem -count=4
func Benchmark_CORS_SubdomainMatch(b *testing.B) {
	origin := "www.example.com"
	sub := subdomain{prefix: "www", suffix: "example.com"}
	b.ReportAllocs()
	for b.Loop() {
		sub.match(origin)
	}
}
// Test_CORS_SubdomainMatch checks subdomain.match against matching and
// non-matching origins, including scheme mismatches and malformed host labels.
// Subtest names are unique so failures are attributable without "#01" suffixes.
func Test_CORS_SubdomainMatch(t *testing.T) {
	tests := []struct {
		name     string
		sub      subdomain
		origin   string
		expected bool
	}{
		{
			name:     "no match with different scheme (http pattern, https origin)",
			sub:      subdomain{prefix: "http://api.", suffix: "example.com"},
			origin:   "https://api.service.example.com",
			expected: false,
		},
		{
			name:     "no match with different scheme (https pattern, http origin)",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "http://api.service.example.com",
			expected: false,
		},
		{
			name:     "match with valid subdomain",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://api.service.example.com",
			expected: true,
		},
		{
			name:     "match with valid nested subdomain",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://1.2.api.service.example.com",
			expected: true,
		},
		{
			name:     "no match with invalid prefix",
			sub:      subdomain{prefix: "https://abc.", suffix: "example.com"},
			origin:   "https://service.example.com",
			expected: false,
		},
		{
			name:     "no match with invalid suffix",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://api.example.org",
			expected: false,
		},
		{
			name:     "no match with empty origin",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "",
			expected: false,
		},
		{
			name:     "no match with malformed subdomain",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://evil.comexample.com",
			expected: false,
		},
		{
			name:     "partial match not considered a match",
			sub:      subdomain{prefix: "https://service.", suffix: "example.com"},
			origin:   "https://api.example.com",
			expected: false,
		},
		{
			name:     "no match with empty host label",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://.example.com",
			expected: false,
		},
		{
			name:     "no match with malformed host label",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://..example.com",
			expected: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.sub.match(tt.origin)
			assert.Equal(t, tt.expected, got, "subdomain.match()")
		})
	}
}
================================================
FILE: middleware/csrf/config.go
================================================
package csrf
import (
"fmt"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/fiber/v3/middleware/session"
"github.com/gofiber/utils/v2"
)
// Config defines the config for CSRF middleware.
//
// Zero-valued fields are replaced by the values in ConfigDefault when the
// config is passed through configDefault (IdleTimeout, CookieName,
// CookieSameSite, KeyGenerator, ErrorHandler, Extractor).
type Config struct {
	// Storage is used to store the state of the middleware.
	//
	// Optional. Default: memory.New()
	// Ignored if Session is set.
	Storage fiber.Storage
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool
	// Session is used to store the state of the middleware.
	//
	// Optional. Default: nil
	// If set, the middleware will use the session store instead of the storage.
	Session *session.Store
	// KeyGenerator creates a new CSRF token.
	//
	// Optional. Default: utils.SecureToken
	KeyGenerator func() string
	// ErrorHandler is called with the error produced by any failed CSRF check
	// or storage operation; its return value is returned by the middleware.
	//
	// Optional. Default: defaultErrorHandler (responds with fiber.ErrForbidden)
	ErrorHandler fiber.ErrorHandler
	// CookieName is the name of the CSRF cookie.
	//
	// Optional. Default: "csrf_"
	CookieName string
	// CookieDomain is the domain of the CSRF cookie.
	//
	// Optional. Default: ""
	CookieDomain string
	// CookiePath is the path of the CSRF cookie.
	//
	// Optional. Default: ""
	CookiePath string
	// CookieSameSite is the SameSite attribute of the CSRF cookie.
	//
	// Optional. Default: "Lax"
	CookieSameSite string
	// TrustedOrigins is a list of trusted origins for unsafe requests.
	// For requests that use the Origin header, the origin must match the
	// Host header or one of the TrustedOrigins.
	// For secure requests that do not include the Origin header, the Referer
	// header must match the Host header or one of the TrustedOrigins.
	//
	// This supports matching subdomains at any level. This means you can use a value like
	// "https://*.example.com" to allow any subdomain of example.com to submit requests,
	// including multiple subdomain levels such as "https://sub.sub.example.com".
	//
	// Entries are whitespace-trimmed and must be valid origins
	// (scheme://host[:port]); New panics on an entry it cannot normalize.
	//
	// Optional. Default: []
	TrustedOrigins []string
	// Extractor returns the CSRF token from the request.
	//
	// Optional. Default: extractors.FromHeader("X-Csrf-Token")
	//
	// Available extractors from github.com/gofiber/fiber/v3/extractors:
	//   - extractors.FromHeader("X-Csrf-Token"): Most secure, recommended for APIs
	//   - extractors.FromForm("_csrf"): Secure, recommended for form submissions
	//   - extractors.FromQuery("csrf_token"): Less secure, URLs may be logged
	//   - extractors.FromParam("csrf"): Less secure, URLs may be logged
	//   - extractors.Chain(...): Advanced chaining of multiple extractors
	//
	// See the Extractors Guide for complete documentation:
	// https://docs.gofiber.io/guide/extractors
	//
	// WARNING: Never create custom extractors that read from cookies with the same
	// CookieName as this defeats CSRF protection entirely. Extractors declaring
	// Source == extractors.SourceCookie with Key == CookieName cause a panic
	// during configuration.
	Extractor extractors.Extractor
	// IdleTimeout is the duration of time the CSRF token is valid. It is used
	// both as the storage expiry of the token and as the cookie lifetime, and
	// is renewed on every request that passes through the middleware.
	//
	// Optional. Default: 30 * time.Minute
	IdleTimeout time.Duration
	// DisableValueRedaction turns off masking CSRF tokens and storage keys in logs and errors.
	//
	// Optional. Default: false
	DisableValueRedaction bool
	// CookieSecure indicates if CSRF cookie is secure.
	//
	// Optional. Default: false
	CookieSecure bool
	// CookieHTTPOnly indicates if CSRF cookie is HTTP only.
	//
	// Optional. Default: false
	CookieHTTPOnly bool
	// CookieSessionOnly decides whether cookie should last for only the browser session.
	// If set to true, the cookie expiry derived from IdleTimeout is not applied
	// to the cookie (the token's storage expiry still uses IdleTimeout).
	//
	// Optional. Default: false
	CookieSessionOnly bool
	// SingleUseToken indicates if the CSRF token should be destroyed
	// and a new one generated on each use (i.e. after each successful
	// validation of an unsafe request).
	//
	// Optional. Default: false
	SingleUseToken bool
}
// HeaderName is the default header name for CSRF tokens. ConfigDefault's
// Extractor reads the token from this request header.
const HeaderName = "X-Csrf-Token"
// ConfigDefault holds the values used by New when no Config is supplied, and
// the fallbacks configDefault applies to zero-valued fields of a supplied one.
// The token is read from the HeaderName request header and stored in a cookie
// named "csrf_" (SameSite=Lax) that stays valid for 30 idle minutes.
var ConfigDefault = Config{
	Extractor:             extractors.FromHeader(HeaderName),
	KeyGenerator:          utils.SecureToken,
	ErrorHandler:          defaultErrorHandler,
	IdleTimeout:           30 * time.Minute,
	CookieName:            "csrf_",
	CookieSameSite:        "Lax",
	DisableValueRedaction: false,
}
// defaultErrorHandler is the ErrorHandler used when none is configured. It
// discards the specific CSRF error and always answers with fiber.ErrForbidden,
// so clients cannot tell which check failed.
func defaultErrorHandler(_ fiber.Ctx, _ error) error { return fiber.ErrForbidden }
// configDefault resolves the effective middleware configuration.
//
// With no argument it returns ConfigDefault unchanged. Otherwise it copies the
// first Config, fills in any zero-valued IdleTimeout (also used when negative),
// CookieName, CookieSameSite, KeyGenerator, ErrorHandler and Extractor from
// ConfigDefault, and then runs validateExtractorSecurity, which panics on an
// extractor that reads the CSRF cookie itself. Additional arguments are ignored.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}

	cfg := config[0]

	if cfg.IdleTimeout <= 0 {
		cfg.IdleTimeout = ConfigDefault.IdleTimeout
	}
	if cfg.CookieName == "" {
		cfg.CookieName = ConfigDefault.CookieName
	}
	if cfg.CookieSameSite == "" {
		cfg.CookieSameSite = ConfigDefault.CookieSameSite
	}
	if cfg.KeyGenerator == nil {
		cfg.KeyGenerator = ConfigDefault.KeyGenerator
	}
	if cfg.ErrorHandler == nil {
		cfg.ErrorHandler = ConfigDefault.ErrorHandler
	}
	// Extractor is a struct, so "unset" is detected through its Extract func.
	if cfg.Extractor.Extract == nil {
		cfg.Extractor = ConfigDefault.Extractor
	}

	// Security validation must run last, after CookieName and Extractor are final.
	validateExtractorSecurity(&cfg)

	return cfg
}
// validateExtractorSecurity rejects extractor configurations that would defeat
// CSRF protection and warns about ones that are merely risky.
//
// It panics if the primary extractor, or any extractor in its Chain, reads
// from a cookie named exactly cfg.CookieName (see isInsecureCookieExtractor).
// It logs one warning per URL-based source (query string or route param) used
// by the primary extractor or by any chained extractor, since tokens carried
// in URLs may end up in access logs. A nil cfg is ignored.
func validateExtractorSecurity(cfg *Config) {
	if cfg == nil {
		return
	}

	// Fatal: an extractor that reads the CSRF cookie itself makes the
	// double-submit comparison trivially satisfiable.
	if isInsecureCookieExtractor(cfg.Extractor, cfg.CookieName) {
		panic("CSRF: Extractor reads from the same cookie '" + cfg.CookieName +
			"' used for token storage. This completely defeats CSRF protection.")
	}
	for i, extractor := range cfg.Extractor.Chain {
		if isInsecureCookieExtractor(extractor, cfg.CookieName) {
			panic(fmt.Sprintf("CSRF: Chained extractor #%d reads from the same cookie '%s' "+
				"used for token storage. This completely defeats CSRF protection.", i+1, cfg.CookieName))
		}
	}

	// Non-fatal: a chain's primary Source only reflects its first element, so
	// e.g. Chain(FromHeader(...), FromQuery(...)) must be inspected element by
	// element to notice the URL-based fallback. Each source is reported once.
	warned := make(map[extractors.Source]struct{}, 2)
	warnIfURLSource := func(source extractors.Source) {
		if source != extractors.SourceQuery && source != extractors.SourceParam {
			return
		}
		if _, done := warned[source]; done {
			return
		}
		warned[source] = struct{}{}
		log.Warnf("[CSRF WARNING] Using %v extractor - URLs may be logged", source)
	}
	warnIfURLSource(cfg.Extractor.Source)
	for _, extractor := range cfg.Extractor.Chain {
		warnIfURLSource(extractor.Source)
	}
}
// isInsecureCookieExtractor reports whether extractor declares that it reads
// the token from the cookie named cookieName (Source == SourceCookie and an
// exact Key match), which would defeat double-submit CSRF protection.
//
// A cookie extractor whose Key equals cookieName only case-insensitively is
// not treated as insecure, but a warning is logged because the names are easy
// to confuse. Non-cookie extractors always yield false.
func isInsecureCookieExtractor(extractor extractors.Extractor, cookieName string) bool {
	if extractor.Source != extractors.SourceCookie {
		return false
	}
	switch {
	case extractor.Key == cookieName:
		return true
	case utils.EqualFold(extractor.Key, cookieName):
		log.Warnf("[CSRF WARNING] Extractor cookie name '%s' is similar to CSRF cookie '%s' - this may be confusing",
			extractor.Key, cookieName)
	}
	return false
}
================================================
FILE: middleware/csrf/config_test.go
================================================
package csrf
import (
"fmt"
"strings"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// Test security validation functions
func Test_CSRF_ExtractorSecurity_Validation(t *testing.T) {
t.Parallel()
// Test secure configurations - should not panic
t.Run("SecureConfigurations", func(t *testing.T) {
t.Parallel()
secureConfigs := []Config{
{Extractor: extractors.FromHeader("X-Csrf-Token")},
{Extractor: extractors.FromForm("_csrf")},
{Extractor: extractors.FromQuery("csrf_token")},
{Extractor: extractors.FromParam("csrf")},
{Extractor: extractors.Chain(extractors.FromHeader("X-Csrf-Token"), extractors.FromForm("_csrf"))},
}
for i, cfg := range secureConfigs {
t.Run(fmt.Sprintf("Config%d", i), func(t *testing.T) {
require.NotPanics(t, func() {
configDefault(cfg)
})
})
}
})
// Test insecure configurations - should panic
t.Run("InsecureCookieExtractor", func(t *testing.T) {
t.Parallel()
// Create a custom extractor that reads from cookie (simulating dangerous behavior)
insecureCookieExtractor := extractors.Extractor{
Extract: func(c fiber.Ctx) (string, error) {
return c.Cookies("csrf_"), nil
},
Source: extractors.SourceCookie,
Key: "csrf_",
}
cfg := Config{
CookieName: "csrf_",
Extractor: insecureCookieExtractor,
}
require.Panics(t, func() {
configDefault(cfg)
}, "Should panic when extractor reads from same cookie")
})
// Test insecure chained extractors
t.Run("InsecureChainedExtractor", func(t *testing.T) {
t.Parallel()
insecureCookieExtractor := extractors.Extractor{
Extract: func(c fiber.Ctx) (string, error) {
return c.Cookies("csrf_"), nil
},
Source: extractors.SourceCookie,
Key: "csrf_",
}
chainedExtractor := extractors.Chain(
extractors.FromHeader("X-Csrf-Token"),
insecureCookieExtractor, // This should trigger panic
)
cfg := Config{
CookieName: "csrf_",
Extractor: chainedExtractor,
}
require.Panics(t, func() {
configDefault(cfg)
}, "Should panic when chained extractor reads from same cookie")
})
// Test different cookie names - should be secure
t.Run("DifferentCookieNames", func(t *testing.T) {
t.Parallel()
cookieExtractor := extractors.Extractor{
Extract: func(c fiber.Ctx) (string, error) {
return c.Cookies("different_cookie"), nil
},
Source: extractors.SourceCookie,
Key: "different_cookie",
}
cfg := Config{
CookieName: "csrf_",
Extractor: cookieExtractor,
}
require.NotPanics(t, func() {
configDefault(cfg)
}, "Should not panic when extractor reads from different cookie")
})
}
// Test_CSRF_Extractor_Metadata verifies that the built-in extractor
// constructors record the expected Source and Key and provide an Extract func.
func Test_CSRF_Extractor_Metadata(t *testing.T) {
	t.Parallel()
	cases := []struct {
		name           string
		expectedKey    string
		extractor      extractors.Extractor
		expectedSource extractors.Source
	}{
		{name: "FromHeader", extractor: extractors.FromHeader("X-Custom-Token"), expectedSource: extractors.SourceHeader, expectedKey: "X-Custom-Token"},
		{name: "FromForm", extractor: extractors.FromForm("_token"), expectedSource: extractors.SourceForm, expectedKey: "_token"},
		{name: "FromQuery", extractor: extractors.FromQuery("token"), expectedSource: extractors.SourceQuery, expectedKey: "token"},
		{name: "FromParam", extractor: extractors.FromParam("id"), expectedSource: extractors.SourceParam, expectedKey: "id"},
	}
	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()
			got := c.extractor
			require.Equal(t, c.expectedSource, got.Source)
			require.Equal(t, c.expectedKey, got.Key)
			require.NotNil(t, got.Extract)
		})
	}
}
// Test_CSRF_Chain_Extractor_Metadata verifies how extractors.Chain derives
// its metadata: an empty chain is SourceCustom with no key, while a non-empty
// chain adopts the first extractor's Source and Key and keeps every link.
func Test_CSRF_Chain_Extractor_Metadata(t *testing.T) {
	t.Parallel()

	t.Run("EmptyChain", func(t *testing.T) {
		t.Parallel()
		empty := extractors.Chain()
		require.Equal(t, extractors.SourceCustom, empty.Source)
		require.Empty(t, empty.Key)
		require.Empty(t, empty.Chain)
	})

	t.Run("SingleExtractor", func(t *testing.T) {
		t.Parallel()
		only := extractors.Chain(extractors.FromHeader("X-Token"))
		require.Equal(t, extractors.SourceHeader, only.Source)
		require.Equal(t, "X-Token", only.Key)
		require.Len(t, only.Chain, 1)
	})

	t.Run("MultipleExtractors", func(t *testing.T) {
		t.Parallel()
		first := extractors.FromHeader("X-Token")
		second := extractors.FromForm("_csrf")
		combined := extractors.Chain(first, second)
		// The chain reports the metadata of its first link.
		require.Equal(t, extractors.SourceHeader, combined.Source)
		require.Equal(t, "X-Token", combined.Key)
		require.Len(t, combined.Chain, 2)
		require.Equal(t, first.Source, combined.Chain[0].Source)
		require.Equal(t, second.Source, combined.Chain[1].Source)
	})
}
// Test_CSRF_Custom_Extractor_Struct runs the middleware with a hand-written
// extractor reading the "X-Custom-CSRF" header: a POST echoing the cookie
// token in that header succeeds, while a POST without the header is rejected.
func Test_CSRF_Custom_Extractor_Struct(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	custom := extractors.Extractor{
		Extract: func(c fiber.Ctx) (string, error) {
			value := c.Get("X-Custom-CSRF")
			if value == "" {
				return "", extractors.ErrNotFound
			}
			return value, nil
		},
		Source: extractors.SourceCustom,
		Key:    "X-Custom-CSRF",
	}
	app.Use(New(Config{Extractor: custom}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// A safe request issues the token through Set-Cookie.
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	setCookie := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token := strings.Split(strings.Split(setCookie, ";")[0], "=")[1]

	// sendPost replays the cookie on a fresh POST, optionally echoing the token
	// in the custom header, and returns the response status.
	sendPost := func(withHeader bool) int {
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		if withHeader {
			ctx.Request.Header.Set("X-Custom-CSRF", token)
		}
		ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		h(ctx)
		return ctx.Response.StatusCode()
	}

	require.Equal(t, 200, sendPost(true))
	require.Equal(t, 403, sendPost(false))
}
// Test_CSRF_Extractor_Error_Types verifies that each built-in extractor
// reports extractors.ErrNotFound when its source does not carry the value.
// The acquired fiber context is released via defer so it is returned to the
// pool even when an assertion aborts the subtest.
func Test_CSRF_Extractor_Error_Types(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		expectedError error
		setupRequest  func(*fasthttp.RequestCtx)
		name          string
		extractor     extractors.Extractor
	}{
		{
			name:      "MissingHeader",
			extractor: extractors.FromHeader("X-Missing"),
			setupRequest: func(_ *fasthttp.RequestCtx) {
				// Don't set the header
			},
			expectedError: extractors.ErrNotFound,
		},
		{
			name:      "MissingForm",
			extractor: extractors.FromForm("_missing"),
			setupRequest: func(ctx *fasthttp.RequestCtx) {
				ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
				// Don't set form data
			},
			expectedError: extractors.ErrNotFound,
		},
		{
			name:      "MissingQuery",
			extractor: extractors.FromQuery("missing"),
			setupRequest: func(ctx *fasthttp.RequestCtx) {
				ctx.Request.SetRequestURI("/")
				// Don't set query param
			},
			expectedError: extractors.ErrNotFound,
		},
		{
			name:      "MissingParam",
			extractor: extractors.FromParam("missing"),
			setupRequest: func(_ *fasthttp.RequestCtx) {
				// No route is matched, so no params exist; the extractor is
				// exercised directly against the bare context.
			},
			expectedError: extractors.ErrNotFound,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			ctx := &fasthttp.RequestCtx{}
			tc.setupRequest(ctx)
			c := app.AcquireCtx(ctx)
			defer app.ReleaseCtx(c)

			_, err := tc.extractor.Extract(c)
			// ErrorIs fails on a nil error as well, so a separate Error check is unnecessary.
			require.ErrorIs(t, err, tc.expectedError)
		})
	}
}
// Test security warning logs (would need to capture log output in real implementation)
func Test_CSRF_Security_Warnings(t *testing.T) {
t.Parallel()
// Test that insecure extractors trigger warnings
// Note: In a real implementation, you'd want to capture log output
// For now, we just test that the configuration doesn't panic
insecureConfigs := []Config{
{Extractor: extractors.FromQuery("csrf_token")},
{Extractor: extractors.FromParam("csrf")},
}
for i, cfg := range insecureConfigs {
t.Run(fmt.Sprintf("InsecureConfig%d", i), func(t *testing.T) {
t.Parallel()
require.NotPanics(t, func() {
configDefault(cfg)
})
})
}
}
// Test_isInsecureCookieExtractor exercises isInsecureCookieExtractor directly.
// Only a cookie extractor whose Key equals the CSRF cookie name exactly is
// insecure; a case-insensitive match only logs a warning and returns false.
func Test_isInsecureCookieExtractor(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		name       string
		cookieName string
		extractor  extractors.Extractor
		expected   bool
	}{
		{
			name: "SecureHeaderExtractor",
			extractor: extractors.Extractor{
				Source: extractors.SourceHeader,
				Key:    "X-Csrf-Token",
			},
			cookieName: "csrf_",
			expected:   false,
		},
		{
			name: "InsecureCookieExtractor",
			extractor: extractors.Extractor{
				Source: extractors.SourceCookie,
				Key:    "csrf_",
			},
			cookieName: "csrf_",
			expected:   true,
		},
		{
			name: "CookieExtractorDifferentName",
			extractor: extractors.Extractor{
				Source: extractors.SourceCookie,
				Key:    "different_cookie",
			},
			cookieName: "csrf_",
			expected:   false,
		},
		{
			// Covers the warn-only branch: names differ only by case.
			name: "CookieExtractorCaseInsensitiveName",
			extractor: extractors.Extractor{
				Source: extractors.SourceCookie,
				Key:    "CSRF_",
			},
			cookieName: "csrf_",
			expected:   false,
		},
		{
			name: "CustomExtractorSafeName",
			extractor: extractors.Extractor{
				Source: extractors.SourceCustom,
				Key:    "safe_key",
			},
			cookieName: "csrf_",
			expected:   false,
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			result := isInsecureCookieExtractor(tc.extractor, tc.cookieName)
			require.Equal(t, tc.expected, result)
		})
	}
}
func Test_CSRF_CookieName_CaseInsensitive_Warning(t *testing.T) {
t.Parallel()
// Extractor uses "CSRF_" (uppercase), config uses "csrf_" (lowercase)
extractor := extractors.Extractor{
Extract: func(c fiber.Ctx) (string, error) {
return c.Cookies("CSRF_"), nil
},
Source: extractors.SourceCookie,
Key: "CSRF_",
}
cfg := Config{
CookieName: "csrf_",
Extractor: extractor,
}
// Should not panic, but should log a warning
require.NotPanics(t, func() {
configDefault(cfg)
}, "Should not panic for case-insensitive cookie name match, but should warn")
}
================================================
FILE: middleware/csrf/csrf.go
================================================
package csrf
import (
"errors"
"fmt"
"net/url"
"slices"
"strings"
"time"
"github.com/gofiber/utils/v2"
utilsstrings "github.com/gofiber/utils/v2/strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
)
// Errors passed to Config.ErrorHandler when a CSRF check fails. With the
// default ErrorHandler all of them result in fiber.ErrForbidden.
var (
	// ErrTokenNotFound: no token could be extracted, the extracted token was
	// empty, the token is unknown to storage, or DeleteToken found no cookie.
	ErrTokenNotFound = errors.New("csrf: token not found")
	// ErrTokenInvalid: the extracted token differs from the CSRF cookie value.
	ErrTokenInvalid = errors.New("csrf: token invalid")
	// ErrFetchSiteInvalid: Sec-Fetch-Site carries an unrecognized value.
	ErrFetchSiteInvalid = errors.New("csrf: sec-fetch-site header invalid")
	// ErrRefererNotFound: an HTTPS unsafe request had neither Origin nor Referer.
	ErrRefererNotFound = errors.New("csrf: referer header missing")
	// ErrRefererInvalid: the Referer header could not be parsed as a URL.
	ErrRefererInvalid = errors.New("csrf: referer header invalid")
	// ErrRefererNoMatch: the Referer matched neither the request host nor a trusted origin.
	ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins")
	// ErrOriginInvalid: the Origin header could not be parsed as a URL.
	ErrOriginInvalid = errors.New("csrf: origin header invalid")
	// ErrOriginNoMatch: the Origin matched neither the request host nor a trusted origin.
	ErrOriginNoMatch  = errors.New("csrf: origin does not match host or trusted origins")
	errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user
	dummyValue        = []byte{'+'}                                  // dummyValue is a placeholder value stored in token storage. The actual token validation relies on the key, not this value.
)
// Handler holds the state of one CSRF middleware instance. It is created by
// New, stored in every request's context, and can be retrieved with
// HandlerFromContext (for example to call DeleteToken).
type Handler struct {
	// sessionManager is non-nil only when Config.Session is set.
	sessionManager *sessionManager
	// storageManager is non-nil only when Config.Session is nil.
	storageManager *storageManager
	// config is the resolved configuration (defaults applied by configDefault).
	config Config
}
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int

// The keys for the values in context
const (
	// tokenKey stores the current CSRF token (string); read by TokenFromContext.
	tokenKey contextKey = iota
	// handlerKey stores the *Handler; read by HandlerFromContext.
	handlerKey
)
// New creates a new middleware handler
//
// It resolves the config via configDefault (which panics on an extractor that
// reads the CSRF cookie itself), selects session- or storage-backed token
// persistence, and pre-parses cfg.TrustedOrigins, panicking on any entry that
// normalizeOrigin rejects. Entries of the form "scheme://*.domain" become
// subdomain wildcard matchers; all others are compared exactly after
// normalization.
//
// For GET, HEAD, OPTIONS and TRACE an existing cookie token is reused if it is
// still present in storage. For every other method the request must pass, in
// order: the Sec-Fetch-Site sanity check, the Origin check (or, for HTTPS
// requests without an Origin, the Referer check), token extraction via
// cfg.Extractor, the double-submit comparison against the cookie, and a
// storage lookup. Every failure is routed through cfg.ErrorHandler.
//
// On success the token (reused or newly generated) is created/extended in
// storage, the cookie is refreshed, "Cookie" is added to Vary, and the token
// is stored in the context for TokenFromContext.
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)
	redactKeys := !cfg.DisableValueRedaction
	// maskValue hides configured values in panic messages unless redaction is disabled.
	maskValue := func(value string) string {
		if redactKeys {
			return redactedKey
		}
		return value
	}
	// Create manager to simplify storage operations ( see *_manager.go )
	// Exactly one of the two managers is non-nil, chosen by cfg.Session.
	var sessionManager *sessionManager
	var storageManager *storageManager
	if cfg.Session != nil {
		sessionManager = newSessionManager(cfg.Session)
	} else {
		storageManager = newStorageManager(cfg.Storage, redactKeys)
	}
	// Pre-parse trusted origins
	trustedOrigins := []string{}
	trustedSubOrigins := []subdomain{}
	for _, origin := range cfg.TrustedOrigins {
		trimmedOrigin := utils.TrimSpace(origin)
		if i := strings.Index(trimmedOrigin, "://*."); i != -1 {
			// Drop the "*." wildcard so the remainder can be validated as a normal origin.
			withoutWildcard := trimmedOrigin[:i+len("://")] + trimmedOrigin[i+len("://*."):]
			isValid, normalizedOrigin := normalizeOrigin(withoutWildcard)
			if !isValid {
				panic("[CSRF] Invalid origin format in configuration:" + maskValue(origin))
			}
			// Split into "scheme://" prefix and host suffix for subdomain matching.
			schemeSep := strings.Index(normalizedOrigin, "://") + len("://")
			sd := subdomain{prefix: normalizedOrigin[:schemeSep], suffix: normalizedOrigin[schemeSep:]}
			trustedSubOrigins = append(trustedSubOrigins, sd)
		} else {
			isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
			if !isValid {
				panic("[CSRF] Invalid origin format in configuration:" + maskValue(origin))
			}
			trustedOrigins = append(trustedOrigins, normalizedOrigin)
		}
	}
	// Create the handler outside of the returned function
	handler := &Handler{
		config:         cfg,
		sessionManager: sessionManager,
		storageManager: storageManager,
	}
	// Return new handler
	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}
		// Store the CSRF handler in the context
		fiber.StoreInContext(c, handlerKey, handler)
		var token string
		// Action depends on the HTTP method
		switch c.Method() {
		case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace:
			// Safe methods: reuse the cookie token only if storage still knows it.
			cookieToken := c.Cookies(cfg.CookieName)
			if cookieToken != "" {
				raw, err := getRawFromStorage(c, cookieToken, &cfg, sessionManager, storageManager)
				if err != nil {
					return cfg.ErrorHandler(c, err)
				}
				if raw != nil {
					token = cookieToken // Token is valid, safe to set it
				}
			}
		default:
			// Assume that anything not defined as 'safe' by RFC7231 needs protection
			// Reject requests whose Sec-Fetch-Site header carries an unrecognized
			// value. Cross-site requests are NOT rejected here; they are subject to
			// the origin/referer checks below, which also admit TrustedOrigins.
			if err := validateSecFetchSite(c); err != nil {
				return cfg.ErrorHandler(c, err)
			}
			// Enforce an origin check for unsafe requests.
			err := originMatchesHost(c, trustedOrigins, trustedSubOrigins)
			// If there's no origin, enforce a referer check for HTTPS connections.
			if errors.Is(err, errOriginNotFound) {
				if c.Scheme() == schemeHTTPS {
					err = refererMatchesHost(c, trustedOrigins, trustedSubOrigins)
				} else {
					// If it's not HTTPS, clear the error to allow the request to proceed.
					err = nil
				}
			}
			// If there's an error (either from origin check or referer check), handle it.
			if err != nil {
				return cfg.ErrorHandler(c, err)
			}
			// Extract token from client request i.e. header, query, param, form
			extractedToken, err := cfg.Extractor.Extract(c)
			if err != nil {
				if errors.Is(err, extractors.ErrNotFound) {
					return cfg.ErrorHandler(c, ErrTokenNotFound)
				}
				// If there's an error during extraction (other than not found), handle it.
				return cfg.ErrorHandler(c, err)
			}
			if extractedToken == "" {
				return cfg.ErrorHandler(c, ErrTokenNotFound)
			}
			// Double Submit Cookie validation: ensure the extracted token matches the cookie value
			// This prevents CSRF attacks by requiring attackers to know both the cookie AND submit
			// the same token through a different channel (header, form, etc.)
			// WARNING: If using a custom extractor that reads from the same cookie, this provides no protection
			if !compareStrings(extractedToken, c.Cookies(cfg.CookieName)) {
				return cfg.ErrorHandler(c, ErrTokenInvalid)
			}
			raw, err := getRawFromStorage(c, extractedToken, &cfg, sessionManager, storageManager)
			if err != nil {
				return cfg.ErrorHandler(c, err)
			}
			if raw == nil {
				// If token is not in storage, expire the cookie
				expireCSRFCookie(c, &cfg)
				// and return an error
				return cfg.ErrorHandler(c, ErrTokenNotFound)
			}
			if cfg.SingleUseToken {
				// If token is single use, delete it from storage
				// (token stays empty, so a fresh one is generated below).
				if err := deleteTokenFromStorage(c, extractedToken, &cfg, sessionManager, storageManager); err != nil {
					return cfg.ErrorHandler(c, err)
				}
			} else {
				token = extractedToken // Token is valid, safe to set it
			}
		}
		// Generate CSRF token if not exist
		if token == "" {
			// And generate a new token
			token = cfg.KeyGenerator()
		}
		// Create or extend the token in the storage
		if err := createOrExtendTokenInStorage(c, token, &cfg, sessionManager, storageManager); err != nil {
			return cfg.ErrorHandler(c, err)
		}
		// Update the CSRF cookie
		updateCSRFCookie(c, &cfg, token)
		// Tell the browser that a new header value is generated
		c.Vary(fiber.HeaderCookie)
		// Store the token in the context
		fiber.StoreInContext(c, tokenKey, token)
		// Continue stack
		return c.Next()
	}
}
// TokenFromContext returns the CSRF token that the middleware stored for the
// current request. ctx may be a fiber.CustomCtx, fiber.Ctx,
// *fasthttp.RequestCtx or context.Context. An empty string is returned when
// no token is present, for example when the middleware did not run.
func TokenFromContext(ctx any) string {
	token, found := fiber.ValueFromContext[string](ctx, tokenKey)
	if !found {
		return ""
	}
	return token
}
// HandlerFromContext returns the CSRF *Handler stored in the context by the
// middleware. ctx may be a fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx or
// context.Context. It returns nil when no handler is present.
func HandlerFromContext(ctx any) *Handler {
	h, found := fiber.ValueFromContext[*Handler](ctx, handlerKey)
	if !found {
		return nil
	}
	return h
}
// getRawFromStorage looks token up in the configured backend and returns its
// stored value. A nil slice means the token is unknown, expired or invalid.
// Session lookups never fail; storage errors are wrapped with context.
func getRawFromStorage(c fiber.Ctx, token string, cfg *Config, sessionManager *sessionManager, storageManager *storageManager) ([]byte, error) {
	if cfg.Session == nil {
		raw, err := storageManager.getRaw(c, token)
		if err != nil {
			return nil, fmt.Errorf("csrf: failed to fetch token from storage: %w", err)
		}
		return raw, nil
	}
	return sessionManager.getRaw(c, token, dummyValue), nil
}
// createOrExtendTokenInStorage saves token (with a placeholder value) in the
// configured backend for cfg.IdleTimeout, creating or refreshing its entry.
// Session writes never fail; storage errors are wrapped with context.
func createOrExtendTokenInStorage(c fiber.Ctx, token string, cfg *Config, sessionManager *sessionManager, storageManager *storageManager) error {
	if cfg.Session == nil {
		err := storageManager.setRaw(c, token, dummyValue, cfg.IdleTimeout)
		if err != nil {
			return fmt.Errorf("csrf: failed to store token in storage: %w", err)
		}
		return nil
	}
	sessionManager.setRaw(c, token, dummyValue, cfg.IdleTimeout)
	return nil
}
// deleteTokenFromStorage removes the CSRF entry from the configured backend.
// The session backend's delRaw takes no token argument, so token is only
// used for the storage backend, whose errors are wrapped with context.
func deleteTokenFromStorage(c fiber.Ctx, token string, cfg *Config, sessionManager *sessionManager, storageManager *storageManager) error {
	if cfg.Session == nil {
		err := storageManager.delRaw(c, token)
		if err != nil {
			return fmt.Errorf("csrf: failed to delete token from storage: %w", err)
		}
		return nil
	}
	sessionManager.delRaw(c)
	return nil
}
// updateCSRFCookie sets the CSRF cookie on the response to token, expiring
// cfg.IdleTimeout from now. Use expireCSRFCookie to remove the cookie instead.
func updateCSRFCookie(c fiber.Ctx, cfg *Config, token string) {
	setCSRFCookie(c, cfg, token, cfg.IdleTimeout)
}
// expireCSRFCookie overwrites the CSRF cookie with an empty value whose expiry
// lies one hour in the past, so the client discards it.
func expireCSRFCookie(c fiber.Ctx, cfg *Config) {
	const alreadyExpired = -time.Hour
	setCSRFCookie(c, cfg, "", alreadyExpired)
}
// setCSRFCookie writes the cookie named cfg.CookieName with the given value to
// the response. It copies the domain, path, Secure, HttpOnly, SameSite and
// session-only settings from cfg and sets the expiry to now plus expiry.
func setCSRFCookie(c fiber.Ctx, cfg *Config, token string, expiry time.Duration) {
	expires := time.Now().Add(expiry)
	c.Cookie(&fiber.Cookie{
		Name:        cfg.CookieName,
		Value:       token,
		Domain:      cfg.CookieDomain,
		Path:        cfg.CookiePath,
		Expires:     expires,
		Secure:      cfg.CookieSecure,
		HTTPOnly:    cfg.CookieHTTPOnly,
		SameSite:    cfg.CookieSameSite,
		SessionOnly: cfg.CookieSessionOnly,
	})
}
// DeleteToken reads the CSRF token from the request cookie named
// Config.CookieName, removes it from the configured storage, and expires the
// CSRF cookie on the response.
//
// If the request has no CSRF cookie, or the storage deletion fails, the error
// is passed to Config.ErrorHandler and its result is returned. The token
// previously stored in the request context is left untouched.
func (handler *Handler) DeleteToken(c fiber.Ctx) error {
	cfg := &handler.config
	token := c.Cookies(cfg.CookieName)
	if token == "" {
		return cfg.ErrorHandler(c, ErrTokenNotFound)
	}
	err := deleteTokenFromStorage(c, token, cfg, handler.sessionManager, handler.storageManager)
	if err != nil {
		return cfg.ErrorHandler(c, err)
	}
	expireCSRFCookie(c, cfg)
	return nil
}
// validateSecFetchSite rejects requests whose Sec-Fetch-Site header holds a
// value outside the four defined ones ("same-origin", "same-site",
// "cross-site", "none"; compared case-insensitively after trimming spaces).
// A missing header is accepted. Note that "cross-site" is accepted here;
// cross-site requests are filtered by the origin/referer checks instead.
func validateSecFetchSite(c fiber.Ctx) error {
	value := utils.Trim(c.Get(fiber.HeaderSecFetchSite), ' ')
	switch utilsstrings.ToLower(value) {
	case "", "same-origin", "same-site", "cross-site", "none":
		return nil
	}
	return ErrFetchSiteInvalid
}
// originMatchesHost validates the (lowercased) Origin header of an unsafe request.
//
// It returns errOriginNotFound when the header is absent or the literal
// "null", so the caller can fall back to a Referer check. Per MDN, browsers
// send "null" for opaque or privacy-sensitive origins.
// It returns ErrOriginInvalid when the header is not a parseable URL. It
// returns nil when the origin's scheme and host match the request, when the
// origin is listed in trustedOrigins, or when it matches a trustedSubOrigins
// wildcard. Otherwise it returns ErrOriginNoMatch.
func originMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
	origin := utilsstrings.ToLower(c.Get(fiber.HeaderOrigin))
	switch origin {
	case "", "null":
		return errOriginNotFound
	}

	parsed, err := url.Parse(origin)
	if err != nil {
		return ErrOriginInvalid
	}

	switch {
	case schemeAndHostMatch(parsed.Scheme, parsed.Host, c.Scheme(), c.Host()):
		return nil
	case slices.Contains(trustedOrigins, origin):
		return nil
	}

	for _, sub := range trustedSubOrigins {
		if sub.match(origin) {
			return nil
		}
	}
	return ErrOriginNoMatch
}
// refererMatchesHost validates the (lowercased) Referer header of an unsafe
// HTTPS request that carried no Origin header.
//
// It returns ErrRefererNotFound when the header is absent and ErrRefererInvalid
// when it is not a parseable URL. It returns nil when the referer's scheme and
// host match the request, or when the referer's origin (scheme://host[:port])
// is listed in trustedOrigins or matches a trustedSubOrigins wildcard.
// Otherwise it returns ErrRefererNoMatch.
//
// Trusted origins are stored as bare origins without path, query or trailing
// slash, while a Referer normally carries a path (for example
// "https://trusted.com/form", or "https://trusted.com/" under the
// strict-origin referrer policies). Only the origin part of the Referer is
// therefore compared against them. Comparing the full URL would never match.
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
	referer := utilsstrings.ToLower(c.Get(fiber.HeaderReferer))
	if referer == "" {
		return ErrRefererNotFound
	}
	refererURL, err := url.Parse(referer)
	if err != nil {
		return ErrRefererInvalid
	}
	if schemeAndHostMatch(refererURL.Scheme, refererURL.Host, c.Scheme(), c.Host()) {
		return nil
	}
	// Reduce the referer to its origin; URL.Host excludes userinfo, path,
	// query and fragment but keeps an explicit port.
	refererOrigin := refererURL.Scheme + "://" + refererURL.Host
	if slices.Contains(trustedOrigins, refererOrigin) {
		return nil
	}
	for _, trustedSubOrigin := range trustedSubOrigins {
		if trustedSubOrigin.match(refererOrigin) {
			return nil
		}
	}
	return ErrRefererNoMatch
}
================================================
FILE: middleware/csrf/csrf_test.go
================================================
package csrf
import (
"context"
"errors"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/fiber/v3/middleware/session"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// failingCSRFStorage is an in-memory storage for tests that can be rigged to
// fail. Values live in data. errs maps "<op>|<key>" (op being "get", "set"
// or "del") to the error that operation should return for that key.
type failingCSRFStorage struct {
	data map[string][]byte // stored values by key
	errs map[string]error  // injected errors by "<op>|<key>"
}

// newFailingCSRFStorage returns an empty storage with no injected errors.
func newFailingCSRFStorage() *failingCSRFStorage {
	s := new(failingCSRFStorage)
	s.data = map[string][]byte{}
	s.errs = map[string]error{}
	return s
}

// GetWithContext returns a copy of the value stored under key, or (nil, nil)
// for an unknown key. A non-nil error injected under "get|"+key takes
// precedence and is returned with a nil value. The context is ignored.
func (s *failingCSRFStorage) GetWithContext(_ context.Context, key string) ([]byte, error) {
	if injected := s.errs["get|"+key]; injected != nil {
		return nil, injected
	}

	val, found := s.data[key]
	if !found {
		return nil, nil
	}
	// Copy so callers cannot mutate the stored bytes.
	return append([]byte(nil), val...), nil
}
// trustedProxyConfig enables proxy trust and lists 0.0.0.0 as the only
// trusted proxy address.
var trustedProxyConfig = fiber.Config{
	TrustProxy: true,
	TrustProxyConfig: fiber.TrustProxyConfig{
		Proxies: []string{"0.0.0.0"},
	},
}

// newTrustedApp builds a Fiber app from trustedProxyConfig.
func newTrustedApp() *fiber.App {
	return fiber.New(trustedProxyConfig)
}

// newTrustedRequestCtx returns a fresh request context whose remote address
// is 0.0.0.0, the proxy address trusted by trustedProxyConfig.
func newTrustedRequestCtx() *fasthttp.RequestCtx {
	var ctx fasthttp.RequestCtx
	ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("0.0.0.0")})
	return &ctx
}
// Get is GetWithContext with a background context.
func (s *failingCSRFStorage) Get(key string) ([]byte, error) {
	return s.GetWithContext(context.Background(), key)
}

// SetWithContext stores a private copy of val under key. If a non-nil error
// has been injected under "set|"+key, it returns that error and stores
// nothing. The context and expiration are ignored.
func (s *failingCSRFStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error {
	if injected := s.errs["set|"+key]; injected != nil {
		return injected
	}
	s.data[key] = append([]byte(nil), val...)
	return nil
}

// Set is SetWithContext with a background context.
func (s *failingCSRFStorage) Set(key string, val []byte, exp time.Duration) error {
	return s.SetWithContext(context.Background(), key, val, exp)
}

// DeleteWithContext removes key. If a non-nil error has been injected under
// "del|"+key, it returns that error and leaves the value in place.
func (s *failingCSRFStorage) DeleteWithContext(_ context.Context, key string) error {
	if injected := s.errs["del|"+key]; injected != nil {
		return injected
	}
	delete(s.data, key)
	return nil
}

// Delete is DeleteWithContext with a background context.
func (s *failingCSRFStorage) Delete(key string) error {
	return s.DeleteWithContext(context.Background(), key)
}

// ResetWithContext drops all stored values and all injected errors.
func (s *failingCSRFStorage) ResetWithContext(context.Context) error {
	s.data = map[string][]byte{}
	s.errs = map[string]error{}
	return nil
}

// Reset is ResetWithContext with a background context.
func (s *failingCSRFStorage) Reset() error {
	return s.ResetWithContext(context.Background())
}

// Close is a no-op; there is nothing to release.
func (*failingCSRFStorage) Close() error { return nil }
// TestCSRFStorageGetError verifies that a storage read failure is reported to
// the ErrorHandler with a descriptive error.
func TestCSRFStorageGetError(t *testing.T) {
	t.Parallel()

	storage := newFailingCSRFStorage()
	storage.errs["get|token"] = errors.New("boom")

	var handlerErr error
	app := fiber.New()
	app.Use(New(Config{
		Storage: storage,
		ErrorHandler: func(_ fiber.Ctx, err error) error {
			handlerErr = err
			return fiber.ErrTeapot
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	// A request carrying a CSRF cookie whose token read is rigged to fail.
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.AddCookie(&http.Cookie{Name: ConfigDefault.CookieName, Value: "token"})

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
	// ErrorContains also fails when handlerErr is nil.
	require.ErrorContains(t, handlerErr, "csrf: failed to fetch token from storage")
}
// TestCSRFStorageSetError verifies that a failure while storing a freshly
// generated token is reported to the ErrorHandler.
func TestCSRFStorageSetError(t *testing.T) {
	t.Parallel()

	storage := newFailingCSRFStorage()
	storage.errs["set|token"] = errors.New("boom")

	var handlerErr error
	app := fiber.New()
	app.Use(New(Config{
		Storage: storage,
		// A fixed key makes the rigged "set|token" entry apply.
		KeyGenerator: func() string { return "token" },
		ErrorHandler: func(_ fiber.Ctx, err error) error {
			handlerErr = err
			return fiber.ErrTeapot
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
	require.ErrorContains(t, handlerErr, "csrf: failed to store token in storage")
}
// TestCSRFStorageDeleteError verifies that, with single-use tokens, a failure
// to delete the consumed token is reported to the ErrorHandler.
func TestCSRFStorageDeleteError(t *testing.T) {
	t.Parallel()

	storage := newFailingCSRFStorage()
	storage.data["token"] = []byte("value")
	storage.errs["del|token"] = errors.New("boom")

	var handlerErr error
	app := fiber.New()
	app.Use(New(Config{
		Storage:        storage,
		SingleUseToken: true,
		ErrorHandler: func(_ fiber.Ctx, err error) error {
			handlerErr = err
			return fiber.ErrTeapot
		},
	}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	// Header and cookie carry the same token, so the request passes
	// validation and reaches the token deletion step.
	req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)
	req.Header.Set(HeaderName, "token")
	req.AddCookie(&http.Cookie{Name: ConfigDefault.CookieName, Value: "token"})

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
	require.ErrorContains(t, handlerErr, "csrf: failed to delete token from storage")
}
// Test_CSRF exercises the basic flow with the default config. For each safe
// method a token is issued through the CSRF cookie. A POST is then rejected
// (403) without a token or with a bogus one, and accepted (200) when the
// header token equals the cookie token.
func Test_CSRF(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	// One RequestCtx is reused for every request below, so the order of the
	// reset calls between requests matters.
	ctx := &fasthttp.RequestCtx{}
	methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}
	for _, method := range methods {
		// Generate CSRF token
		ctx.Request.Header.SetMethod(method)
		h(ctx)
		// Without CSRF cookie
		ctx.Request.Header.Reset()
		ctx.Request.ResetBody()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		h(ctx)
		require.Equal(t, 403, ctx.Response.StatusCode())
		// Invalid CSRF token
		ctx.Request.Header.Reset()
		ctx.Request.ResetBody()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, "johndoe")
		h(ctx)
		require.Equal(t, 403, ctx.Response.StatusCode())
		// Valid CSRF token
		ctx.Request.Header.Reset()
		ctx.Request.ResetBody()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(method)
		h(ctx)
		// Take the value of the first "name=value" pair of the Set-Cookie header.
		token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
		token = strings.Split(strings.Split(token, ";")[0], "=")[1]
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, token)
		ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		h(ctx)
		require.Equal(t, 200, ctx.Response.StatusCode())
	}
}
// Test_CSRF_WithSession runs the token flow with a session-backed CSRF
// middleware. A session identified by the "_session" cookie is created up
// front and sent with every request. Tokens issued for that session must be
// accepted, and missing or bogus tokens rejected.
func Test_CSRF_WithSession(t *testing.T) {
	t.Parallel()
	// session store
	store := session.NewStore(session.Config{
		Extractor: extractors.FromCookie("_session"),
	})
	// fiber instance
	app := fiber.New()
	// fiber context
	ctx := &fasthttp.RequestCtx{}
	// Acquire a single Fiber context for the setup below and release it when
	// the test returns, so no pooled context is leaked.
	setupCtx := app.AcquireCtx(ctx)
	defer app.ReleaseCtx(setupCtx)
	// get session
	sess, err := store.Get(setupCtx)
	require.NoError(t, err)
	require.True(t, sess.Fresh())
	// The session ID is generated, so read it back instead of hard-coding it.
	newSessionIDString := sess.ID()
	require.NoError(t, sess.Save())
	setupCtx.Request().Header.SetCookie("_session", newSessionIDString)
	// middleware config
	config := Config{
		Session: store,
	}
	// middleware
	app.Use(New(config))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}
	for _, method := range methods {
		// Generate CSRF token
		ctx.Request.Header.SetMethod(fiber.MethodGet)
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		h(ctx)
		// Without CSRF cookie
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		h(ctx)
		require.Equal(t, 403, ctx.Response.StatusCode())
		// Empty/invalid CSRF token
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, "johndoe")
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		h(ctx)
		require.Equal(t, 403, ctx.Response.StatusCode())
		// Valid CSRF token
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(method)
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		h(ctx)
		// Pick the CSRF cookie's value out of the Set-Cookie header.
		token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
		for header := range strings.SplitSeq(token, ";") {
			if strings.Split(utils.TrimSpace(header), "=")[0] == ConfigDefault.CookieName {
				token = strings.Split(header, "=")[1]
				break
			}
		}
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
		ctx.Request.Header.Set(HeaderName, token)
		ctx.Request.Header.SetCookie("_session", newSessionIDString)
		ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		h(ctx)
		require.Equal(t, 200, ctx.Response.StatusCode())
	}
}
// go test -run Test_CSRF_WithSession_Middleware
//
// Test_CSRF_WithSession_Middleware uses the session middleware together with
// a CSRF middleware backed by the same session store. It checks that a token
// issued on GET is accepted on POST and that session data set on GET is still
// visible afterwards.
func Test_CSRF_WithSession_Middleware(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	sessionHandler, sessionStore := session.NewWithStore()
	csrfHandler := New(Config{Session: sessionStore})
	app.Use(sessionHandler)
	app.Use(csrfHandler)

	app.Get("/", func(c fiber.Ctx) error {
		session.FromContext(c).Set("hello", "world")
		return c.SendStatus(fiber.StatusOK)
	})
	app.Post("/", func(c fiber.Ctx) error {
		if session.FromContext(c).Get("hello") != "world" {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// responseCookie returns the value of the named cookie from the last
	// response, failing the test if the cookie was not set.
	responseCookie := func(name string) string {
		cookie := fasthttp.AcquireCookie()
		defer fasthttp.ReleaseCookie(cookie)
		cookie.SetKey(name)
		require.True(t, ctx.Response.Header.Cookie(cookie))
		return string(cookie.Value())
	}

	// Generate CSRF token and session_id
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)

	csrfToken := responseCookie(ConfigDefault.CookieName)
	require.NotEmpty(t, csrfToken)
	sessionID := responseCookie("session_id")
	require.NotEmpty(t, sessionID)

	// Use the CSRF token and session_id
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, csrfToken)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, csrfToken)
	ctx.Request.Header.SetCookie("session_id", sessionID)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
}
// go test -run Test_CSRF_ExpiredToken
//
// Test_CSRF_ExpiredToken checks that a token is accepted within IdleTimeout
// and rejected (403) once the timeout has elapsed. It sleeps for a little
// over one second.
func Test_CSRF_ExpiredToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		IdleTimeout: 1 * time.Second,
	}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	// Take the value of the first "name=value" pair of the Set-Cookie header.
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	// Use the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	// Wait for the token to expire
	time.Sleep(1250 * time.Millisecond)
	// Expired CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}
// go test -run Test_CSRF_ExpiredToken_WithSession
//
// Test_CSRF_ExpiredToken_WithSession checks that a session-backed token is
// accepted within IdleTimeout and rejected (403) once it has elapsed.
func Test_CSRF_ExpiredToken_WithSession(t *testing.T) {
	t.Parallel()
	// session store
	store := session.NewStore(session.Config{
		Extractor: extractors.FromCookie("_session"),
	})
	// fiber instance
	app := fiber.New()
	// fiber context
	ctx := &fasthttp.RequestCtx{}
	// Acquire a single Fiber context for the setup below and release it when
	// the test returns, so no pooled context is leaked.
	setupCtx := app.AcquireCtx(ctx)
	defer app.ReleaseCtx(setupCtx)
	// get session
	sess, err := store.Get(setupCtx)
	require.NoError(t, err)
	require.True(t, sess.Fresh())
	// get session id
	newSessionIDString := sess.ID()
	require.NoError(t, sess.Save())
	setupCtx.Request().Header.SetCookie("_session", newSessionIDString)
	// middleware config
	config := Config{
		Session:     store,
		IdleTimeout: 1 * time.Second,
	}
	// middleware
	app.Use(New(config))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	h(ctx)
	// Pick the CSRF cookie's value out of the Set-Cookie header.
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	for header := range strings.SplitSeq(token, ";") {
		if strings.Split(utils.TrimSpace(header), "=")[0] == ConfigDefault.CookieName {
			token = strings.Split(header, "=")[1]
			break
		}
	}
	// Use the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	// Wait for the token to expire
	time.Sleep(1*time.Second + 100*time.Millisecond)
	// Expired CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}
// go test -run Test_CSRF_MultiUseToken
//
// Test_CSRF_MultiUseToken checks the default (multi-use) token mode with a
// custom header extractor. A bogus token is rejected. A valid token is
// accepted, and the same token value is re-issued in the response cookie.
func Test_CSRF_MultiUseToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Extractor: extractors.FromHeader("X-Csrf-Token"),
	}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set("X-Csrf-Token", "johndoe")
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	// Use the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set("X-Csrf-Token", token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	// Assert success before parsing Set-Cookie. On failure the cookie may be
	// missing, and the index expression below would panic with a less useful
	// message.
	require.Equal(t, 200, ctx.Response.StatusCode())
	newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1]
	// A multi-use token is re-issued unchanged after a successful request.
	require.Equal(t, token, newToken)
}
// go test -run Test_CSRF_SingleUseToken
//
// Test_CSRF_SingleUseToken checks that with SingleUseToken enabled a token is
// accepted once, a different token is issued in its place, and replaying the
// consumed token is rejected (403).
func Test_CSRF_SingleUseToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		SingleUseToken: true,
	}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	// Use the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1]
	require.NotEqual(t, token, newToken, "new token should not be the same as the old token")
	// Use the CSRF token again
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}
// go test -run Test_CSRF_Next
//
// Test_CSRF_Next checks that when Next returns true the middleware is skipped.
// No route is registered, so the request ends in a 404.
func Test_CSRF_Next(t *testing.T) {
	t.Parallel()

	skipAll := func(_ fiber.Ctx) bool { return true }

	app := fiber.New()
	app.Use(New(Config{Next: skipAll}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// Test_CSRF_From_Form checks token extraction from the "_csrf" form field of
// an application/x-www-form-urlencoded body.
func Test_CSRF_From_Form(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{Extractor: extractors.FromForm("_csrf")}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	// Invalid CSRF token
	// (the form body is empty, so no token can be extracted)
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	// Submit the token in the form body alongside the matching cookie.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	ctx.Request.SetBodyString("_csrf=" + token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
}
// Test_CSRF_From_Query checks token extraction from the "_csrf" query
// parameter. A random UUID without a matching cookie is rejected, and the
// issued token is accepted.
func Test_CSRF_From_Query(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{Extractor: extractors.FromQuery("_csrf")}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/?_csrf=" + utils.UUIDv4())
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	// Send the token in the query string alongside the matching cookie.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.SetRequestURI("/?_csrf=" + token)
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "OK", string(ctx.Response.Body()))
}
// Test_CSRF_From_Param checks token extraction from a route parameter. The
// middleware is mounted on the "/:csrf" group and reads the token from the
// ":csrf" segment.
func Test_CSRF_From_Param(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	csrfGroup := app.Group("/:csrf", New(Config{Extractor: extractors.FromParam("csrf")}))
	csrfGroup.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	// Invalid CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/" + utils.UUIDv4())
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
	// Generate CSRF token
	// (the GET goes through the group middleware, which sets the cookie)
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/" + utils.UUIDv4())
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	// Send the token as the path segment alongside the matching cookie.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.SetRequestURI("/" + token)
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "OK", string(ctx.Response.Body()))
}
// Test_CSRF_From_Custom checks a custom extractor that reads the token from a
// plain-text body of the form "key=value".
func Test_CSRF_From_Custom(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	extractor := extractors.Extractor{
		Extract: func(c fiber.Ctx) (string, error) {
			body := string(c.Body())
			// Generate the correct extractor to get the token from the correct location
			// (the body must contain exactly one "=" with a non-empty value after it)
			selectors := strings.Split(body, "=")
			if len(selectors) != 2 || selectors[1] == "" {
				return "", extractors.ErrNotFound
			}
			return selectors[1], nil
		},
		Source: extractors.SourceCustom,
		Key:    "_csrf",
	}
	app.Use(New(Config{Extractor: extractor}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	// Invalid CSRF token
	// (empty body, so the extractor returns ErrNotFound)
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	// Send "_csrf=<token>" as the body alongside the matching cookie.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
	ctx.Request.SetBodyString("_csrf=" + token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
}
// Test_CSRF_Extractor_EmptyString checks that an extractor returning ("",
// nil) is treated as "no token". The request is rejected with
// ErrTokenNotFound even though a valid CSRF cookie is present.
func Test_CSRF_Extractor_EmptyString(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	extractor := extractors.Extractor{
		Extract: func(_ fiber.Ctx) (string, error) {
			return "", nil
		},
		Source: extractors.SourceCustom,
		Key:    "_csrf",
	}
	// Expose the error text in the body so it can be asserted below.
	errorHandler := func(c fiber.Ctx, err error) error {
		return c.Status(403).SendString(err.Error())
	}
	app.Use(New(Config{
		Extractor:    extractor,
		ErrorHandler: errorHandler,
	}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}
	// Generate CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	// The body and cookie carry a real token, but the extractor ignores them.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
	ctx.Request.SetBodyString("_csrf=" + token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
	require.Equal(t, ErrTokenNotFound.Error(), string(ctx.Response.Body()))
}
// Test_CSRF_SecFetchSite checks how Sec-Fetch-Site validation combines with
// the Origin check.
//   - An unknown Sec-Fetch-Site value is rejected with ErrFetchSiteInvalid.
//   - Unsafe methods with a valid token pass only with an acceptable Origin
//     (or with no Origin at all over plain HTTP).
//   - Safe methods pass regardless of the header.
func Test_CSRF_SecFetchSite(t *testing.T) {
	t.Parallel()
	// Expose the error text in the body so ErrFetchSiteInvalid can be asserted.
	errorHandler := func(c fiber.Ctx, err error) error {
		return c.Status(fiber.StatusForbidden).SendString(err.Error())
	}
	app := newTrustedApp()
	app.Use(New(Config{ErrorHandler: errorHandler}))
	app.All("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	// Issue one token up front. The parallel subtests only read it.
	ctx := newTrustedRequestCtx()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetHost("example.com")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	tests := []struct {
		name                   string
		method                 string
		secFetchSite           string
		origin                 string
		expectedStatus         int
		https                  bool
		expectFetchSiteInvalid bool
	}{
		{
			name:           "same-origin allowed",
			method:         fiber.MethodPost,
			secFetchSite:   "same-origin",
			origin:         "http://example.com",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "none allowed",
			method:         fiber.MethodPost,
			secFetchSite:   "none",
			origin:         "http://example.com",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "cross-site with origin allowed",
			method:         fiber.MethodPost,
			secFetchSite:   "cross-site",
			origin:         "http://example.com",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "same-site with origin allowed",
			method:         fiber.MethodPost,
			secFetchSite:   "same-site",
			origin:         "http://example.com",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "cross-site with mismatched origin blocked",
			method:         fiber.MethodPost,
			secFetchSite:   "cross-site",
			origin:         "https://attacker.example",
			expectedStatus: http.StatusForbidden,
		},
		{
			name:           "same-site with null origin blocked",
			method:         fiber.MethodPost,
			secFetchSite:   "same-site",
			origin:         "null",
			expectedStatus: http.StatusForbidden,
			https:          true,
		},
		{
			name:                   "invalid header blocked",
			method:                 fiber.MethodPost,
			secFetchSite:           "weird",
			origin:                 "http://example.com",
			expectedStatus:         http.StatusForbidden,
			expectFetchSiteInvalid: true,
		},
		{
			name:           "no header with no origin",
			method:         fiber.MethodPost,
			origin:         "",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "no header with matching origin",
			method:         fiber.MethodPost,
			origin:         "http://example.com",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "no header with mismatched origin",
			method:         fiber.MethodPost,
			origin:         "https://attacker.example",
			expectedStatus: http.StatusForbidden,
		},
		{
			name:           "no header with null origin",
			method:         fiber.MethodPost,
			origin:         "null",
			expectedStatus: http.StatusForbidden,
			https:          true,
		},
		{
			name:           "GET allowed",
			method:         fiber.MethodGet,
			secFetchSite:   "cross-site",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "HEAD allowed",
			method:         fiber.MethodHead,
			secFetchSite:   "cross-site",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "OPTIONS allowed",
			method:         fiber.MethodOptions,
			secFetchSite:   "cross-site",
			expectedStatus: http.StatusOK,
		},
		{
			name:           "PUT with mismatched origin blocked",
			method:         fiber.MethodPut,
			secFetchSite:   "cross-site",
			origin:         "https://attacker.example",
			expectedStatus: http.StatusForbidden,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// Use a context from the trusted proxy address so the
			// X-Forwarded-Proto header set for the HTTPS cases is honoured
			// explicitly, rather than relying on fasthttp's default remote
			// address for a zero RequestCtx.
			c := newTrustedRequestCtx()
			scheme := "http"
			if tt.https {
				scheme = "https"
			}
			c.Request.Header.SetMethod(tt.method)
			c.Request.URI().SetScheme(scheme)
			c.Request.URI().SetHost("example.com")
			c.Request.Header.SetHost("example.com")
			c.Request.Header.SetProtocol(scheme)
			if scheme == "https" {
				c.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
			}
			if tt.origin != "" {
				c.Request.Header.Set(fiber.HeaderOrigin, tt.origin)
			}
			if tt.secFetchSite != "" {
				c.Request.Header.Set(fiber.HeaderSecFetchSite, tt.secFetchSite)
			}
			// Only unsafe methods need the token; safe ones are sent without it.
			safe := tt.method == fiber.MethodGet || tt.method == fiber.MethodHead || tt.method == fiber.MethodOptions || tt.method == fiber.MethodTrace
			if !safe {
				c.Request.Header.Set(HeaderName, token)
				c.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			}
			h(c)
			require.Equal(t, tt.expectedStatus, c.Response.StatusCode())
			if tt.expectFetchSiteInvalid {
				require.Equal(t, ErrFetchSiteInvalid.Error(), string(c.Response.Body()))
			}
		})
	}
}
// Test_CSRF_Origin checks Origin-header validation for unsafe requests on an
// app that trusts the 0.0.0.0 proxy.
//   - Matching origins pass, including default-port equivalence and a public
//     host supplied via X-Forwarded-Host.
//   - Mismatching origins, and a proxied request missing X-Forwarded-Host,
//     get 403.
func Test_CSRF_Origin(t *testing.T) {
	t.Parallel()
	app := newTrustedApp()
	app.Use(New(Config{CookieSecure: true}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := newTrustedRequestCtx()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
	// Test Correct Origin with port
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com:8080")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com:8080")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com:8080")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	// Test Origin without default HTTP port against host with default port
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com:80")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com:80")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	// Test Origin with default HTTP port against host without port
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com:80")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	// Test Correct Origin with wrong port
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com:3000")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
	// Test Correct Origin with null
	// "null" is treated like a missing Origin; over plain HTTP the request is
	// still expected to succeed.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "null")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	// Test Correct Origin with ReverseProxy
	// The URI and Host header point at the proxy (10.0.1.42:8080), while
	// X-Forwarded-Host supplies the public host the Origin must match.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("10.0.1.42:8080")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("10.0.1.42:8080")
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderXForwardedFor, `192.0.2.43, "[2001:db8:cafe::17]"`)
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	// Test Origin without default HTTPS port against host with default port
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("example.com:443")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.SetHost("example.com:443")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	// Test Origin with default HTTPS port against host without port
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com:443")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	// Test Correct Origin with ReverseProxy Missing X-Forwarded-* Headers
	// Without X-Forwarded-Host the request host stays the proxy's address,
	// so the public Origin does not match.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("10.0.1.42:8080")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("10.0.1.42:8080")
	ctx.Request.Header.Set(fiber.HeaderXUrlScheme, "http") // We need to set this header to make sure c.Protocol() returns http
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
	// Test Wrong Origin
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://csrf.example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}
func Test_CSRF_TrustedOrigins(t *testing.T) {
	t.Parallel()

	app := newTrustedApp()
	app.Use(New(Config{
		CookieSecure: true,
		TrustedOrigins: []string{
			"http://safe.example.com",
			"https://safe.example.com",
			"http://*.domain-1.com",
			"https://*.domain-1.com",
		},
	}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	handler := app.Handler()
	rc := newTrustedRequestCtx()

	// A GET over forwarded HTTPS makes the middleware issue a token cookie;
	// the cookie value is then reused as the submitted token.
	rc.Request.Header.SetMethod(fiber.MethodGet)
	rc.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	handler(rc)
	setCookie := string(rc.Response.Header.Peek(fiber.HeaderSetCookie))
	nameValue := strings.Split(setCookie, ";")[0]
	token := strings.Split(nameValue, "=")[1]

	// Each scenario POSTs to scheme://host with the given Origin or Referer.
	// Columns: label, scheme, host, header, source, wantStatus, forwarded
	// (forwarded => also send X-Forwarded-Proto equal to scheme).
	scenarios := []struct {
		label      string
		scheme     string
		host       string
		header     string
		source     string
		wantStatus int
		forwarded  bool
	}{
		{"Trusted Origin", "http", "example.com", fiber.HeaderOrigin, "http://safe.example.com", 200, false},
		{"Trusted Origin Subdomain", "http", "domain-1.com", fiber.HeaderOrigin, "http://safe.domain-1.com", 200, false},
		{"Trusted Origin deeply nested subdomain", "https", "a.b.c.domain-1.com", fiber.HeaderOrigin, "https://a.b.c.domain-1.com", 200, false},
		{"Trusted Origin Invalid", "http", "domain-1.com", fiber.HeaderOrigin, "http://evildomain-1.com", 403, false},
		{"Trusted Origin malformed subdomain", "http", "domain-1.com", fiber.HeaderOrigin, "http://evil.comdomain-1.com", 403, false},
		{"Trusted Referer", "https", "example.com", fiber.HeaderReferer, "https://safe.example.com", 200, true},
		{"Trusted Referer Wildcard", "https", "domain-1.com", fiber.HeaderReferer, "https://safe.domain-1.com", 200, true},
		{"Trusted Referer deeply nested subdomain", "https", "a.b.c.domain-1.com", fiber.HeaderReferer, "https://a.b.c.domain-1.com", 200, true},
		{"Trusted Referer Invalid", "https", "api.domain-1.com", fiber.HeaderReferer, "https://evildomain-1.com", 403, true},
	}

	for i := range scenarios {
		sc := &scenarios[i]
		rc.Request.Reset()
		rc.Response.Reset()
		rc.Request.Header.SetMethod(fiber.MethodPost)
		if sc.forwarded {
			rc.Request.Header.Set(fiber.HeaderXForwardedProto, sc.scheme)
		}
		rc.Request.URI().SetScheme(sc.scheme)
		rc.Request.URI().SetHost(sc.host)
		rc.Request.Header.SetProtocol(sc.scheme)
		rc.Request.Header.SetHost(sc.host)
		rc.Request.Header.Set(sc.header, sc.source)
		rc.Request.Header.Set(HeaderName, token)
		rc.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		handler(rc)
		require.Equal(t, sc.wantStatus, rc.Response.StatusCode(), sc.label)
	}
}
func Test_CSRF_TrustedOrigins_InvalidOrigins(t *testing.T) {
	t.Parallel()

	// Every pattern below is malformed; registering the middleware with any
	// of them as a trusted origin must panic.
	cases := []struct {
		name   string
		origin string
	}{
		{"No Scheme", "localhost"},
		{"Wildcard", "https://*"},
		{"Wildcard domain", "https://*example.com"},
		{"File Scheme", "file://example.com"},
		{"FTP Scheme", "ftp://example.com"},
		{"Port Wildcard", "http://example.com:*"},
		{"Multiple Wildcards", "https://*.*.com"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			require.Panics(t, func() {
				app := fiber.New()
				app.Use(New(Config{
					CookieSecure:   true,
					TrustedOrigins: []string{tc.origin},
				}))
			}, "Expected panic")
		})
	}
}
// Test_CSRF_Referer exercises the Referer validation applied to unsafe
// requests. Referers that match the effective scheme and host are accepted;
// this includes default-port normalization and X-Forwarded-* values from the
// trusted proxy. Mismatching referers are rejected with 403.
func Test_CSRF_Referer(t *testing.T) {
	t.Parallel()
	app := newTrustedApp()
	app.Use(New(Config{CookieSecure: true}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := newTrustedRequestCtx()

	// Issue a token via GET; the cookie value doubles as the submitted token.
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// newPost clears the shared request/response and starts a fresh POST.
	newPost := func() {
		ctx.Request.Reset()
		ctx.Response.Reset()
		ctx.Request.Header.SetMethod(fiber.MethodPost)
	}
	// target addresses the request to scheme://host through X-Forwarded-Proto,
	// the request URI, the protocol and the Host header, keeping them consistent.
	target := func(scheme, host string) {
		ctx.Request.Header.Set(fiber.HeaderXForwardedProto, scheme)
		ctx.Request.URI().SetScheme(scheme)
		ctx.Request.URI().SetHost(host)
		ctx.Request.Header.SetProtocol(scheme)
		ctx.Request.Header.SetHost(host)
	}
	// submit attaches the token (header + cookie), runs the app and returns the status.
	submit := func() int {
		ctx.Request.Header.Set(HeaderName, token)
		ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		h(ctx)
		return ctx.Response.StatusCode()
	}

	// Correct Referer with port
	newPost()
	target("https", "example.com:8443")
	ctx.Request.Header.Set(fiber.HeaderReferer, ctx.Request.URI().String())
	require.Equal(t, 200, submit())

	// Referer without default HTTPS port against host with default port
	newPost()
	target("https", "example.com:443")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com")
	require.Equal(t, 200, submit())

	// Referer with default HTTPS port against host without port
	newPost()
	target("https", "example.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com:443")
	require.Equal(t, 200, submit())

	// Correct Referer with ReverseProxy: URI and Host header both name the
	// proxy address; X-Forwarded-Host names the public host the Referer matches.
	newPost()
	target("https", "10.0.1.42:8443")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderXForwardedFor, `192.0.2.43, "[2001:db8:cafe::17]"`)
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com")
	require.Equal(t, 200, submit())

	// Referer without default HTTP port against host with default port
	newPost()
	target("http", "example.com:80")
	ctx.Request.Header.Set(fiber.HeaderReferer, "http://example.com")
	require.Equal(t, 200, submit())

	// Referer with default HTTP port against host without port
	newPost()
	target("http", "example.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "http://example.com:80")
	require.Equal(t, 200, submit())

	// Correct Referer with ReverseProxy Missing X-Forwarded-* Headers: nothing
	// in the request names example.com, so the Referer must be rejected.
	newPost()
	target("https", "10.0.1.42:8443")
	ctx.Request.Header.Set(fiber.HeaderXUrlScheme, "https") // We need to set this header to make sure c.Protocol() returns https
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com")
	require.Equal(t, 403, submit())

	// Correct Referer with path (only forwarded headers describe the target)
	newPost()
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com/action/items?gogogo=true")
	require.Equal(t, 200, submit())

	// Wrong Referer
	newPost()
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://csrf.example.com")
	require.Equal(t, 403, submit())
}
// Test_CSRF_DeleteToken checks Handler.DeleteToken:
//   - Called before any token is issued and with all cookies removed, it
//     reports ErrTokenNotFound.
//   - Called after a token was issued, it invalidates that token, so a later
//     POST replaying the token is rejected with 403.
//
// NOTE(review): the fiber contexts obtained via app.AcquireCtx are never passed
// back to app.ReleaseCtx here.
func Test_CSRF_DeleteToken(t *testing.T) {
t.Parallel()
app := fiber.New()
config := ConfigDefault
app.Use(New(config))
app.Post("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Call DeleteToken before any token has been generated, with all cookies
// removed: it must report ErrTokenNotFound.
ctx.Request.Header.Reset()
ctx.Request.ResetBody()
ctx.Response.Reset()
ctx.Request.Header.Set(HeaderName, "")
// NOTE(review): this nil guard silently skips the assertion if no handler is
// found on the freshly acquired ctx; consider require.NotNil to fail loudly.
handler := HandlerFromContext(app.AcquireCtx(ctx))
if handler != nil {
ctx.Request.Header.DelAllCookies()
err := handler.DeleteToken(app.AcquireCtx(ctx))
require.ErrorIs(t, err, ErrTokenNotFound)
}
h(ctx)
// Generate CSRF token: a GET response carries it in Set-Cookie.
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
// Delete the CSRF token through the handler attached to the request context,
// then run a POST carrying the (now deleted) token.
ctx.Request.Header.Reset()
ctx.Request.ResetBody()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
handler = HandlerFromContext(app.AcquireCtx(ctx))
if handler != nil {
if err := handler.DeleteToken(app.AcquireCtx(ctx)); err != nil {
t.Fatal(err)
}
}
h(ctx)
// Replay the deleted token once more; it must be rejected.
ctx.Request.Header.Reset()
ctx.Request.ResetBody()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 403, ctx.Response.StatusCode())
}
// Test_CSRF_DeleteToken_WithSession checks that Handler.DeleteToken also
// invalidates the token when the middleware is configured with a session
// store. A POST replaying the deleted token, together with the same session
// cookie, must be rejected with 403.
//
// NOTE(review): the fiber contexts obtained via app.AcquireCtx are never passed
// back to app.ReleaseCtx here.
func Test_CSRF_DeleteToken_WithSession(t *testing.T) {
t.Parallel()
// session store that reads the session ID from the "_session" cookie
store := session.NewStore(session.Config{
Extractor: extractors.FromCookie("_session"),
})
// fiber instance
app := fiber.New()
// fasthttp context shared by every request in this test
ctx := &fasthttp.RequestCtx{}
// create (and persist) a fresh session up front
sess, err := store.Get(app.AcquireCtx(ctx))
require.NoError(t, err)
require.True(t, sess.Fresh())
// remember the ID of the freshly created session so later requests can send it
newSessionIDString := sess.ID()
require.NoError(t, sess.Save())
app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
// middleware config: store tokens in the session
config := Config{
Session: store,
}
// middleware
app.Use(New(config))
app.Post("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
// Generate CSRF token: a GET within the session; token read from Set-Cookie
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.SetCookie("_session", newSessionIDString)
h(ctx)
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
// Delete the CSRF token via the handler attached to the request context.
// NOTE(review): unlike the final request, this one does not carry the
// _session cookie; confirm that is intended.
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
// NOTE(review): the nil guard silently skips DeleteToken if no handler is found.
handler := HandlerFromContext(app.AcquireCtx(ctx))
if handler != nil {
if err := handler.DeleteToken(app.AcquireCtx(ctx)); err != nil {
t.Fatal(err)
}
}
h(ctx)
// Replay the deleted token within the original session; it must be rejected.
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
ctx.Request.Header.SetCookie("_session", newSessionIDString)
h(ctx)
require.Equal(t, 403, ctx.Response.StatusCode())
}
func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
	t.Parallel()

	// The custom handler asserts it received ErrTokenInvalid and answers 419.
	onError := func(c fiber.Ctx, err error) error {
		require.Equal(t, ErrTokenInvalid, err)
		return c.Status(419).Send([]byte("invalid CSRF token"))
	}

	app := fiber.New()
	app.Use(New(Config{ErrorHandler: onError}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	serve := app.Handler()
	rc := &fasthttp.RequestCtx{}

	// Let the middleware issue a token first.
	rc.Request.Header.SetMethod(fiber.MethodGet)
	serve(rc)

	// POST a bogus token in the header (no cookie attached).
	rc.Request.Reset()
	rc.Response.Reset()
	rc.Request.Header.SetMethod(fiber.MethodPost)
	rc.Request.Header.Set(HeaderName, "johndoe")
	serve(rc)

	require.Equal(t, 419, rc.Response.StatusCode())
	require.Equal(t, "invalid CSRF token", string(rc.Response.Body()))
}
func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		// Expect ErrTokenNotFound and answer with a distinctive 419 response.
		ErrorHandler: func(c fiber.Ctx, err error) error {
			require.Equal(t, ErrTokenNotFound, err)
			return c.Status(419).Send([]byte("empty CSRF token"))
		},
	}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	serve := app.Handler()
	rc := &fasthttp.RequestCtx{}

	// Let the middleware issue a token first.
	rc.Request.Header.SetMethod(fiber.MethodGet)
	serve(rc)

	// POST with no token in either the header or a cookie.
	rc.Request.Reset()
	rc.Response.Reset()
	rc.Request.Header.SetMethod(fiber.MethodPost)
	serve(rc)

	require.Equal(t, 419, rc.Response.StatusCode())
	require.Equal(t, "empty CSRF token", string(rc.Response.Body()))
}
// Test_CSRF_ErrorHandler_MissingReferer checks the custom ErrorHandler path for
// an unsafe request over forwarded HTTPS that has a matching token but neither
// an Origin nor a Referer header. The handler must receive ErrRefererNotFound.
func Test_CSRF_ErrorHandler_MissingReferer(t *testing.T) {
	t.Parallel()
	app := newTrustedApp()
	errHandler := func(ctx fiber.Ctx, err error) error {
		require.ErrorIs(t, err, ErrRefererNotFound)
		return ctx.Status(419).Send([]byte("missing referer"))
	}
	app.Use(New(Config{
		CookieSecure: true,
		ErrorHandler: errHandler,
	}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	h := app.Handler()
	ctx := newTrustedRequestCtx()

	// Issue a token over forwarded HTTPS.
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// POST with a matching token but no Origin/Referer header.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 419, ctx.Response.StatusCode())
	// The body proves the response was produced by errHandler.
	require.Equal(t, "missing referer", string(ctx.Response.Body()))
}
func Test_CSRF_Cookie_Injection_Exploit(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	serve := app.Handler()

	// Raw Cookie header an attacker could plant in the victim's browser.
	const injected = "csrf_=pwned;"
	rc := &fasthttp.RequestCtx{}

	// Step 1: a GET carrying the planted cookie; capture the token answered.
	rc.Request.Reset()
	rc.Response.Reset()
	rc.Request.Header.SetMethod(fiber.MethodGet)
	rc.Request.Header.Set(fiber.HeaderCookie, injected)
	rc.Request.SetRequestURI("/")
	serve(rc)
	issued := string(rc.Response.Header.Peek(fiber.HeaderSetCookie))
	issued = strings.Split(strings.Split(issued, ";")[0], "=")[1]

	// Step 2: POST that token while still presenting only the planted cookie.
	rc.Request.Reset()
	rc.Response.Reset()
	rc.Request.Header.SetMethod(fiber.MethodPost)
	rc.Request.Header.Set(HeaderName, issued)
	rc.Request.Header.Set(fiber.HeaderCookie, injected)
	serve(rc)

	require.Equal(t, 403, rc.Response.StatusCode(), "CSRF exploit successful")
}
// Test_CSRF_UnsafeHeaderValue guards the fix for the unsafe header values
// described in https://github.com/gofiber/fiber/issues/2045. Every request
// below must be answered with 200:
//   - a plain GET;
//   - a GET carrying the token header;
//   - XHR-style GETs with Cache-Control/Accept headers plus the token cookie;
//   - an XHR-style POST with a matching header and cookie.
//
// go test -race -run Test_CSRF_UnsafeHeaderValue
func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
	t.Parallel()

	ok := func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	}
	app := fiber.New()
	app.Use(New())
	app.Get("/", ok)
	app.Get("/test", ok)
	app.Post("/", ok)

	// mustOK runs req through the app and requires a 200 response.
	mustOK := func(req *http.Request) *http.Response {
		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)
		return resp
	}

	// A plain GET issues the token cookie.
	first := mustOK(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	var token string
	for _, ck := range first.Cookies() {
		if ck.Name == ConfigDefault.CookieName {
			token = ck.Value
			break
		}
	}
	t.Log("token", token)

	tokenCookie := &http.Cookie{Name: ConfigDefault.CookieName, Value: token}

	// GET with only the token header.
	plainGet := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	plainGet.Header.Set(HeaderName, token)
	mustOK(plainGet)

	// XHR-style GET with header and cookie, then the same request again with an
	// Accept header added and the token header removed.
	xhrGet := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
	xhrGet.Header.Set("X-Requested-With", "XMLHttpRequest")
	xhrGet.Header.Set(fiber.HeaderCacheControl, "no")
	xhrGet.Header.Set(HeaderName, token)
	xhrGet.AddCookie(tokenCookie)
	mustOK(xhrGet)
	xhrGet.Header.Set(fiber.HeaderAccept, "*/*")
	xhrGet.Header.Del(HeaderName)
	mustOK(xhrGet)

	// XHR-style POST with matching header and cookie.
	xhrPost := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)
	xhrPost.Header.Set("X-Requested-With", "XMLHttpRequest")
	xhrPost.Header.Set(HeaderName, token)
	xhrPost.AddCookie(tokenCookie)
	mustOK(xhrPost)
}
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
//
// Benchmark_Middleware_CSRF_Check measures the validation path: the same POST
// carrying a previously issued token (header + cookie) is replayed each iteration.
func Benchmark_Middleware_CSRF_Check(b *testing.B) {
	teapot := func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	}
	app := fiber.New()
	app.Use(New())
	app.Get("/", teapot)
	app.Post("/", teapot)
	serve := app.Handler()
	rc := &fasthttp.RequestCtx{}

	// Obtain a token from a GET.
	rc.Request.Header.SetMethod(fiber.MethodGet)
	serve(rc)
	raw := string(rc.Response.Header.Peek(fiber.HeaderSetCookie))
	pair := strings.Split(raw, ";")[0]
	token := strings.Split(pair, "=")[1]

	// Build the POST (with a matching Referer) that the loop replays.
	rc.Request.Reset()
	rc.Response.Reset()
	req := &rc.Request
	req.Header.SetMethod(fiber.MethodPost)
	req.Header.Set(fiber.HeaderXForwardedProto, "https")
	req.URI().SetScheme("https")
	req.URI().SetHost("example.com")
	req.Header.SetProtocol("https")
	req.Header.SetHost("example.com")
	req.Header.Set(fiber.HeaderReferer, "https://example.com")
	req.Header.Set(HeaderName, token)
	req.Header.SetCookie(ConfigDefault.CookieName, token)

	b.ReportAllocs()
	for b.Loop() {
		serve(rc)
	}
	require.Equal(b, fiber.StatusTeapot, rc.Response.Header.StatusCode())
}
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
//
// Benchmark_Middleware_CSRF_GenerateToken replays the same cookie-less GET each
// iteration to measure the cost of serving a safe request through the middleware.
func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})
	serve := app.Handler()

	rc := &fasthttp.RequestCtx{}
	rc.Request.Header.SetMethod(fiber.MethodGet)

	b.ReportAllocs()
	for b.Loop() {
		serve(rc)
	}

	// The route handler must still have run, answering 418.
	require.Equal(b, fiber.StatusTeapot, rc.Response.Header.StatusCode())
}
func Test_CSRF_InvalidURLHeaders(t *testing.T) {
	t.Parallel()

	app := newTrustedApp()
	// Echo the middleware's error text with status 419 so it can be asserted.
	app.Use(New(Config{ErrorHandler: func(c fiber.Ctx, err error) error {
		return c.Status(419).Send([]byte(err.Error()))
	}}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	serve := app.Handler()
	rc := newTrustedRequestCtx()

	// Issue a token.
	rc.Request.Header.SetMethod(fiber.MethodGet)
	rc.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
	serve(rc)
	setCookie := string(rc.Response.Header.Peek(fiber.HeaderSetCookie))
	token := strings.Split(strings.Split(setCookie, ";")[0], "=")[1]

	// An unparsable Origin and an unparsable Referer each yield their own error.
	cases := []struct {
		scheme    string
		header    string
		value     string
		wantBody  string
		forwarded bool
	}{
		{"http", fiber.HeaderOrigin, "http://[::1]:%38%30/Invalid Origin", ErrOriginInvalid.Error(), false},
		{"https", fiber.HeaderReferer, "http://[::1]:%38%30/Invalid Referer", ErrRefererInvalid.Error(), true},
	}
	for i := range cases {
		tc := &cases[i]
		rc.Request.Reset()
		rc.Response.Reset()
		rc.Request.Header.SetMethod(fiber.MethodPost)
		if tc.forwarded {
			rc.Request.Header.Set(fiber.HeaderXForwardedProto, tc.scheme)
		}
		rc.Request.URI().SetScheme(tc.scheme)
		rc.Request.URI().SetHost("example.com")
		rc.Request.Header.SetProtocol(tc.scheme)
		rc.Request.Header.SetHost("example.com")
		rc.Request.Header.Set(tc.header, tc.value)
		rc.Request.Header.Set(HeaderName, token)
		rc.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		serve(rc)
		require.Equal(t, 419, rc.Response.StatusCode())
		require.Equal(t, tc.wantBody, string(rc.Response.Body()))
	}
}
func Test_CSRF_TokenFromContext(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		// The middleware has run for this request, so a token must be available.
		require.NotEmpty(t, TokenFromContext(c))
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
func Test_CSRF_FromContextMethods(t *testing.T) {
	t.Parallel()

	// PassLocalsToContext lets the lookups also work on c.Context().
	app := fiber.New(fiber.Config{PassLocalsToContext: true})
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		tok := TokenFromContext(c)
		require.NotEmpty(t, tok)
		h := HandlerFromContext(c)
		require.NotNil(t, h)

		// Same values through the fiber.CustomCtx view of the context...
		custom, isCustom := c.(fiber.CustomCtx)
		require.True(t, isCustom)
		require.Equal(t, tok, TokenFromContext(custom))
		require.Equal(t, h, HandlerFromContext(custom))

		// ...and through the fasthttp request context and c.Context().
		require.Equal(t, tok, TokenFromContext(c.RequestCtx()))
		require.Equal(t, tok, TokenFromContext(c.Context()))
		require.Equal(t, h, HandlerFromContext(c.RequestCtx()))
		require.Equal(t, h, HandlerFromContext(c.Context()))
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
func Test_CSRF_FromContextMethods_Invalid(t *testing.T) {
	t.Parallel()

	// No CSRF middleware is registered, so both lookups must come back empty.
	app := fiber.New()
	app.Get("/", func(c fiber.Ctx) error {
		require.Empty(t, TokenFromContext(c))
		require.Nil(t, HandlerFromContext(c))
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
func Test_deleteTokenFromStorage(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() { app.ReleaseCtx(c) })

	const key = "token123"
	payload := []byte("dummy")

	// Session configured: the token saved through the session manager must be
	// gone after deletion.
	store := session.NewStore()
	sessMgr := newSessionManager(store)
	storeMgr := newStorageManager(nil, true)
	sessMgr.setRaw(c, key, payload, time.Minute)
	sessionCfg := Config{Session: store}
	require.NoError(t, deleteTokenFromStorage(c, key, &sessionCfg, sessMgr, storeMgr))
	require.Nil(t, sessMgr.getRaw(c, key, payload))

	// No session configured: the token saved through the storage manager must
	// be gone after deletion.
	noSessMgr := newSessionManager(nil)
	plainStoreMgr := newStorageManager(nil, true)
	require.NoError(t, plainStoreMgr.setRaw(context.Background(), key, payload, time.Minute))
	plainCfg := Config{}
	require.NoError(t, deleteTokenFromStorage(c, key, &plainCfg, noSessMgr, plainStoreMgr))
	leftover, err := plainStoreMgr.getRaw(context.Background(), key)
	require.NoError(t, err)
	require.Nil(t, leftover)
}
func Test_storageManager_logKey(t *testing.T) {
t.Parallel()
redacted := newStorageManager(nil, true)
require.Equal(t, redactedKey, redacted.logKey("secret"))
plain := newStorageManager(nil, false)
require.Equal(t, "secret", plain.logKey("secret"))
}
func Test_CSRF_Chain_Extractor(t *testing.T) {
	t.Parallel()

	// Look for the token in the X-Csrf-Token header first, falling back to the
	// _csrf form field.
	app := fiber.New()
	app.Use(New(Config{Extractor: extractors.Chain(
		extractors.FromHeader("X-Csrf-Token"),
		extractors.FromForm("_csrf"),
	)}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	serve := app.Handler()
	rc := &fasthttp.RequestCtx{}

	// Issue a token.
	rc.Request.Header.SetMethod(fiber.MethodGet)
	serve(rc)
	setCookie := string(rc.Response.Header.Peek(fiber.HeaderSetCookie))
	token := strings.Split(strings.Split(setCookie, ";")[0], "=")[1]

	// Empty header/body values mean "do not send". form adds the form
	// Content-Type; cookie attaches the issued token cookie.
	cases := []struct {
		desc   string
		header string
		body   string
		want   int
		form   bool
		cookie bool
	}{
		{"token in header only", token, "", 200, false, true},
		{"token in form only (fallback)", "", "_csrf=" + token, 200, true, true},
		{"header takes precedence over wrong form value", token, "_csrf=wrong_token", 200, true, true},
		{"no token in either location", "", "", 403, true, false},
		{"wrong token in both locations", "wrong_token", "_csrf=also_wrong", 403, true, true},
	}
	for i := range cases {
		tc := &cases[i]
		rc.Request.Reset()
		rc.Response.Reset()
		rc.Request.Header.SetMethod(fiber.MethodPost)
		if tc.form {
			rc.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
		}
		if tc.header != "" {
			rc.Request.Header.Set("X-Csrf-Token", tc.header)
		}
		if tc.body != "" {
			rc.Request.SetBodyString(tc.body)
		}
		if tc.cookie {
			rc.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		}
		serve(rc)
		require.Equal(t, tc.want, rc.Response.StatusCode(), tc.desc)
	}
}
func Test_CSRF_Chain_Extractor_Empty(t *testing.T) {
	t.Parallel()

	// A chain with no extractors can never yield a token, so an unsafe request
	// is rejected even when the header and cookie agree.
	app := fiber.New()
	app.Use(New(Config{Extractor: extractors.Chain()}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	serve := app.Handler()
	rc := &fasthttp.RequestCtx{}

	// Issue a token.
	rc.Request.Header.SetMethod(fiber.MethodGet)
	serve(rc)
	setCookie := string(rc.Response.Header.Peek(fiber.HeaderSetCookie))
	token := strings.Split(strings.Split(setCookie, ";")[0], "=")[1]

	rc.Request.Reset()
	rc.Response.Reset()
	rc.Request.Header.SetMethod(fiber.MethodPost)
	rc.Request.Header.Set("X-Csrf-Token", token)
	rc.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	serve(rc)

	require.Equal(t, 403, rc.Response.StatusCode())
}
func Test_CSRF_Chain_Extractor_SingleExtractor(t *testing.T) {
	t.Parallel()

	// A one-element chain should behave exactly like the header extractor alone.
	app := fiber.New()
	app.Use(New(Config{Extractor: extractors.Chain(extractors.FromHeader("X-Csrf-Token"))}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	serve := app.Handler()
	rc := &fasthttp.RequestCtx{}

	// Issue a token.
	rc.Request.Header.SetMethod(fiber.MethodGet)
	serve(rc)
	setCookie := string(rc.Response.Header.Peek(fiber.HeaderSetCookie))
	token := strings.Split(strings.Split(setCookie, ";")[0], "=")[1]

	// An empty header value means the header is not sent at all.
	for _, tc := range []struct {
		header string
		want   int
	}{
		{header: token, want: 200},
		{header: "", want: 403},
	} {
		rc.Request.Reset()
		rc.Response.Reset()
		rc.Request.Header.SetMethod(fiber.MethodPost)
		if tc.header != "" {
			rc.Request.Header.Set("X-Csrf-Token", tc.header)
		}
		rc.Request.Header.SetCookie(ConfigDefault.CookieName, token)
		serve(rc)
		require.Equal(t, tc.want, rc.Response.StatusCode())
	}
}
// Test_CSRF_All_Extractors runs the CSRF middleware with each basic extractor
// (header, form, query). In every case the issued token is sent as the cookie;
// the positive cases also place it where the extractor looks, and the
// "_Missing" cases omit it and expect 403.
func Test_CSRF_All_Extractors(t *testing.T) {
	t.Parallel()

	cases := []struct {
		setupRequest func(ctx *fasthttp.RequestCtx, token string)
		name         string
		extractor    extractors.Extractor
		expectStatus int
	}{
		{
			name:      "FromHeader",
			extractor: extractors.FromHeader("X-Csrf-Token"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.Set("X-Csrf-Token", token)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 200,
		},
		{
			name:      "FromHeader_Missing",
			extractor: extractors.FromHeader("X-Csrf-Token"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 403,
		},
		{
			name:      "FromForm",
			extractor: extractors.FromForm("_csrf"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
				ctx.Request.SetBodyString("_csrf=" + token)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 200,
		},
		{
			name:      "FromForm_Missing",
			extractor: extractors.FromForm("_csrf"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 403,
		},
		{
			name:      "FromQuery",
			extractor: extractors.FromQuery("csrf_token"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.SetRequestURI("/?csrf_token=" + token)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 200,
		},
		{
			name:      "FromQuery_Missing",
			extractor: extractors.FromQuery("csrf_token"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.SetRequestURI("/")
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 403,
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Use(New(Config{Extractor: tt.extractor}))
			app.Post("/", func(c fiber.Ctx) error {
				return c.SendStatus(fiber.StatusOK)
			})
			handler := app.Handler()
			reqCtx := &fasthttp.RequestCtx{}

			// Issue a token with a safe request to the root path.
			reqCtx.Request.Header.SetMethod(fiber.MethodGet)
			reqCtx.Request.SetRequestURI("/")
			handler(reqCtx)
			cookiePair := strings.Split(string(reqCtx.Response.Header.Peek(fiber.HeaderSetCookie)), ";")[0]
			token := strings.Split(cookiePair, "=")[1]

			// Replay it through the case-specific request.
			reqCtx.Request.Reset()
			reqCtx.Response.Reset()
			tt.setupRequest(reqCtx, token)
			handler(reqCtx)

			got := reqCtx.Response.StatusCode()
			require.Equal(t, tt.expectStatus, got,
				"Test case %s failed: expected %d, got %d", tt.name, tt.expectStatus, got)
		})
	}
}
// Test_CSRF_Param_Extractor mounts the CSRF middleware on a "/:csrf" group
// using a route-parameter extractor. A request whose path segment equals the
// issued token is accepted, and any other segment is rejected.
func Test_CSRF_Param_Extractor(t *testing.T) {
	t.Parallel()

	cases := []struct {
		setupRequest func(ctx *fasthttp.RequestCtx, token string)
		name         string
		expectStatus int
	}{
		{
			name: "FromParam_Valid",
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.SetRequestURI("/" + token)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 200,
		},
		{
			name: "FromParam_Invalid",
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.SetRequestURI("/wrong_token")
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 403,
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// The middleware lives only on the param group so ":csrf" is always bound.
			app := fiber.New()
			group := app.Group("/:csrf", New(Config{Extractor: extractors.FromParam("csrf")}))
			group.Post("/", func(c fiber.Ctx) error {
				return c.SendStatus(fiber.StatusOK)
			})
			handler := app.Handler()
			reqCtx := &fasthttp.RequestCtx{}

			// A safe request to any "/:csrf" path issues a fresh token cookie.
			reqCtx.Request.Header.SetMethod(fiber.MethodGet)
			reqCtx.Request.SetRequestURI("/" + utils.UUIDv4())
			handler(reqCtx)
			setCookie := string(reqCtx.Response.Header.Peek(fiber.HeaderSetCookie))
			token := strings.Split(strings.Split(setCookie, ";")[0], "=")[1]

			reqCtx.Request.Reset()
			reqCtx.Response.Reset()
			tt.setupRequest(reqCtx, token)
			handler(reqCtx)

			got := reqCtx.Response.StatusCode()
			require.Equal(t, tt.expectStatus, got,
				"Test case %s failed: expected %d, got %d", tt.name, tt.expectStatus, got)
		})
	}
}
// Test_CSRF_Param_Extractor_Missing registers a param-based extractor globally
// but routes POST "/", which declares no ":csrf" parameter. Extraction
// therefore finds nothing and the middleware must answer 403.
func Test_CSRF_Param_Extractor_Missing(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{Extractor: extractors.FromParam("csrf")}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	handler := app.Handler()
	reqCtx := &fasthttp.RequestCtx{}

	// Issue a token with a safe request.
	reqCtx.Request.Header.SetMethod(fiber.MethodGet)
	reqCtx.Request.SetRequestURI("/")
	handler(reqCtx)
	rawCookie := string(reqCtx.Response.Header.Peek(fiber.HeaderSetCookie))
	token := strings.Split(strings.Split(rawCookie, ";")[0], "=")[1]

	// POST "/" with only the cookie: there is no route parameter to read.
	reqCtx.Request.Reset()
	reqCtx.Response.Reset()
	reqCtx.Request.Header.SetMethod(fiber.MethodPost)
	reqCtx.Request.SetRequestURI("/")
	reqCtx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	handler(reqCtx)
	require.Equal(t, 403, reqCtx.Response.StatusCode(), "Missing param should return 403")
}
// Test_CSRF_Extractors_ErrorTypes checks that every basic extractor returns an
// empty token and extractors.ErrNotFound when its source lacks the value.
func Test_CSRF_Extractors_ErrorTypes(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		expected  error
		setupCtx  func(ctx *fasthttp.RequestCtx) // prepares the raw request before the fiber context is acquired
		name      string
		extractor extractors.Extractor
	}{
		{
			name:      "Missing header",
			extractor: extractors.FromHeader("X-Missing-Header"),
			expected:  extractors.ErrNotFound,
			setupCtx:  func(_ *fasthttp.RequestCtx) {}, // No setup needed for headers
		},
		{
			name:      "Missing query",
			extractor: extractors.FromQuery("missing_param"),
			expected:  extractors.ErrNotFound,
			setupCtx: func(ctx *fasthttp.RequestCtx) {
				ctx.Request.SetRequestURI("/") // Set URI for query parsing
			},
		},
		{
			name:      "Missing param",
			extractor: extractors.FromParam("missing_param"),
			expected:  extractors.ErrNotFound,
			setupCtx:  func(_ *fasthttp.RequestCtx) {}, // Params are handled by router
		},
		{
			name:      "Missing form",
			extractor: extractors.FromForm("missing_field"),
			expected:  extractors.ErrNotFound,
			setupCtx: func(ctx *fasthttp.RequestCtx) {
				// Properly initialize request for form parsing
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.SetContentType(fiber.MIMEApplicationForm)
				ctx.Request.SetBodyString("") // Empty form body
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			requestCtx := &fasthttp.RequestCtx{}
			tc.setupCtx(requestCtx) // Set up the context properly
			ctx := app.AcquireCtx(requestCtx)
			defer app.ReleaseCtx(ctx)

			token, err := tc.extractor.Extract(ctx)
			require.Empty(t, token)
			// ErrorIs also accepts the sentinel when an extractor wraps it with
			// extra context; require.Equal would demand the exact error value.
			require.ErrorIs(t, err, tc.expected)
		})
	}
}
================================================
FILE: middleware/csrf/helpers.go
================================================
package csrf
import (
"crypto/subtle"
"net/url"
"strings"
"github.com/gofiber/utils/v2"
utilsstrings "github.com/gofiber/utils/v2/strings"
)
// schemeHTTP is the only non-TLS URL scheme accepted for origins; it implies
// default port 80 in normalizeSchemeHost.
const schemeHTTP = "http"

// schemeHTTPS is the TLS URL scheme accepted for origins; it implies default
// port 443 in normalizeSchemeHost.
const schemeHTTPS = "https"
// compareTokens reports whether a and b contain identical bytes. It uses
// subtle.ConstantTimeCompare, so the time taken does not depend on where the
// contents differ (slices of different lengths are unequal).
func compareTokens(a, b []byte) bool {
	verdict := subtle.ConstantTimeCompare(a, b)
	return verdict == 1
}
// compareStrings reports whether a and b are equal, using a constant-time
// byte comparison. The strings are converted with utils.UnsafeBytes
// (presumably a zero-copy view; the bytes must not be modified).
func compareStrings(a, b string) bool {
	left, right := utils.UnsafeBytes(a), utils.UnsafeBytes(b)
	return subtle.ConstantTimeCompare(left, right) == 1
}
// schemeAndHostMatch reports whether (schemeA, hostA) and (schemeB, hostB)
// describe the same scheme and host.
//
// Schemes are compared after lowercasing. Hosts are compared after
// normalizeSchemeHost, which lowercases them and adds the scheme's default
// port. For example, "https"+"example.com" equals "https"+"example.com:443".
func schemeAndHostMatch(schemeA, hostA, schemeB, hostB string) bool {
	lowerA := utilsstrings.ToLower(schemeA)
	lowerB := utilsstrings.ToLower(schemeB)
	if lowerA != lowerB {
		return false
	}
	return normalizeSchemeHost(lowerA, hostA) == normalizeSchemeHost(lowerB, hostB)
}
// normalizeSchemeHost lowercases host and, for the http and https schemes,
// appends the scheme's default port (80 or 443) when host has no explicit port.
//
// The host is returned lowercased but otherwise unchanged in these cases:
//   - any other scheme;
//   - an explicit port is present;
//   - the host cannot be parsed, or has an empty hostname.
//
// A bare IPv6 hostname is wrapped in brackets before the port is appended.
// scheme is expected to be lowercase already.
func normalizeSchemeHost(scheme, host string) string {
	lowered := utilsstrings.ToLower(host)

	port := ""
	switch scheme {
	case schemeHTTPS:
		port = "443"
	case schemeHTTP:
		port = "80"
	}
	if port == "" {
		return lowered
	}

	u, err := url.Parse(scheme + "://" + lowered)
	if err != nil || u.Port() != "" {
		return lowered
	}

	name := u.Hostname()
	if name == "" {
		return lowered
	}
	// Hostname strips the brackets from IPv6 literals; restore them so the
	// appended ":port" stays unambiguous.
	if strings.Contains(name, ":") && !strings.HasPrefix(name, "[") {
		name = "[" + name + "]"
	}
	return name + ":" + port
}
// normalizeOrigin checks whether origin is a bare http(s) origin and returns
// its canonical form.
//
// A valid origin meets all of the following:
//   - it uses the http or https scheme;
//   - it has a non-empty host containing no '*' wildcard;
//   - it has no userinfo;
//   - it has no path other than a single "/";
//   - it has no query (not even an empty "?") and no fragment.
//
// The normalized origin is "scheme://host" in lowercase, without a trailing
// slash. For an invalid origin it returns false and an empty string.
func normalizeOrigin(origin string) (valid bool, normalized string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming validity and normalized origin results
	parsedOrigin, err := url.Parse(origin)
	if err != nil {
		return false, ""
	}

	// Validate the scheme is either http or https (url.Parse lowercases it).
	if parsedOrigin.Scheme != schemeHTTP && parsedOrigin.Scheme != schemeHTTPS {
		return false, ""
	}

	// Don't allow a wildcard with a protocol.
	// Wildcards cannot be used within any other value. For example, the following header is not valid:
	// Access-Control-Allow-Origin: https://*
	if strings.IndexByte(parsedOrigin.Host, '*') >= 0 {
		return false, ""
	}

	// An origin's serialization never carries credentials (RFC 6454 §6.2).
	// Reject "http://user:pass@host" instead of silently dropping the userinfo.
	if parsedOrigin.User != nil {
		return false, ""
	}

	// Require a host and reject path, query, or fragment components. A trailing
	// "/" (the root path) is allowed and normalized away. ForceQuery catches a
	// bare "?", which leaves RawQuery empty.
	if parsedOrigin.Host == "" ||
		(parsedOrigin.Path != "" && parsedOrigin.Path != "/") ||
		parsedOrigin.RawQuery != "" || parsedOrigin.ForceQuery ||
		parsedOrigin.Fragment != "" {
		return false, ""
	}

	// Build the normalized origin from scheme and host only; any path or
	// trailing slash is left out.
	return true, utilsstrings.ToLower(parsedOrigin.Scheme) + "://" + utilsstrings.ToLower(parsedOrigin.Host)
}
// subdomain is a wildcard-origin pattern split around its subdomain part.
// prefix is the text before the subdomain (e.g. "https://") and suffix the
// text after the separating dot (e.g. "example.com"). It is presumably built
// from a trusted origin such as "https://*.example.com"; confirm with the
// caller.
type subdomain struct {
	prefix string
	suffix string
}

// match reports whether o is s.prefix, then one or more non-empty
// dot-separated labels, then a '.', then s.suffix.
//
// Origins with an empty label anywhere in the subdomain part are rejected.
// This includes "https://.example.com", "https://..example.com" and
// "https://a..example.com".
func (s subdomain) match(o string) bool {
	// Not a subdomain if not long enough for a label plus a dot separator.
	if len(o) < len(s.prefix)+len(s.suffix)+1 {
		return false
	}
	if !strings.HasPrefix(o, s.prefix) || !strings.HasSuffix(o, s.suffix) {
		return false
	}

	suffixStartIndex := len(o) - len(s.suffix)
	if suffixStartIndex <= len(s.prefix) {
		return false
	}

	// sub is the subdomain part including its trailing separator dot, e.g.
	// "api.service." for "https://api.service.example.com". It is non-empty
	// because of the check above.
	sub := o[len(s.prefix):suffixStartIndex]
	if sub[len(sub)-1] != '.' {
		return false
	}

	// A leading dot means an empty first label. Any ".." means an empty label
	// in the middle or, because sub keeps its separator dot, right before the
	// suffix (e.g. "a..example.com").
	if sub[0] == '.' || strings.Contains(sub, "..") {
		return false
	}
	return true
}
================================================
FILE: middleware/csrf/helpers_test.go
================================================
package csrf
import (
"testing"
"github.com/stretchr/testify/assert"
)
// go test -v -run Test_normalizeOrigin
//
// Test_normalizeOrigin checks, for each raw origin, the validity and the
// normalized form that normalizeOrigin reports.
func Test_normalizeOrigin(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		origin         string
		expectedOrigin string
		expectedValid  bool
	}{
		{origin: "http://example.com", expectedValid: true, expectedOrigin: "http://example.com"},                                         // Simple case should work.
		{origin: "HTTP://EXAMPLE.COM", expectedValid: true, expectedOrigin: "http://example.com"},                                         // Case should be normalized.
		{origin: "http://example.com/", expectedValid: true, expectedOrigin: "http://example.com"},                                        // Trailing slash should be removed.
		{origin: "http://example.com:3000", expectedValid: true, expectedOrigin: "http://example.com:3000"},                               // Port should be preserved.
		{origin: "http://example.com:3000/", expectedValid: true, expectedOrigin: "http://example.com:3000"},                              // Trailing slash should be removed.
		{origin: "http://", expectedValid: false, expectedOrigin: ""},                                                                     // Invalid origin should not be accepted.
		{origin: "file:///etc/passwd", expectedValid: false, expectedOrigin: ""},                                                          // File scheme should not be accepted.
		{origin: "https://*example.com", expectedValid: false, expectedOrigin: ""},                                                        // Wildcard domain should not be accepted.
		{origin: "http://*.example.com", expectedValid: false, expectedOrigin: ""},                                                        // Wildcard subdomain should not be accepted.
		{origin: "http://example.com/path", expectedValid: false, expectedOrigin: ""},                                                     // Path should not be accepted.
		{origin: "http://example.com?query=123", expectedValid: false, expectedOrigin: ""},                                                // Query should not be accepted.
		{origin: "http://example.com#fragment", expectedValid: false, expectedOrigin: ""},                                                 // Fragment should not be accepted.
		{origin: "http://localhost", expectedValid: true, expectedOrigin: "http://localhost"},                                             // Localhost should be accepted.
		{origin: "http://127.0.0.1", expectedValid: true, expectedOrigin: "http://127.0.0.1"},                                             // IPv4 address should be accepted.
		{origin: "http://[::1]", expectedValid: true, expectedOrigin: "http://[::1]"},                                                     // IPv6 address should be accepted.
		{origin: "http://[::1]:8080", expectedValid: true, expectedOrigin: "http://[::1]:8080"},                                           // IPv6 address with port should be accepted.
		{origin: "http://[::1]:8080/", expectedValid: true, expectedOrigin: "http://[::1]:8080"},                                          // IPv6 address with port and trailing slash should be accepted.
		{origin: "http://[::1]:8080/path", expectedValid: false, expectedOrigin: ""},                                                      // IPv6 address with port and path should not be accepted.
		{origin: "http://[::1]:8080?query=123", expectedValid: false, expectedOrigin: ""},                                                 // IPv6 address with port and query should not be accepted.
		{origin: "http://[::1]:8080#fragment", expectedValid: false, expectedOrigin: ""},                                                  // IPv6 address with port and fragment should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment", expectedValid: false, expectedOrigin: ""},                                   // IPv6 address with port, path, query, and fragment should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/", expectedValid: false, expectedOrigin: ""},                                  // IPv6 address with port, path, query, fragment, and trailing slash should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/invalid", expectedValid: false, expectedOrigin: ""},                           // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/invalid/", expectedValid: false, expectedOrigin: ""},                          // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with trailing slash should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/invalid/segment", expectedValid: false, expectedOrigin: ""},                   // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with additional segment should not be accepted.
	}

	for _, tc := range testCases {
		t.Run(tc.origin, func(t *testing.T) {
			t.Parallel()

			gotValid, gotOrigin := normalizeOrigin(tc.origin)
			if gotValid != tc.expectedValid {
				t.Errorf("Expected origin '%s' to be valid: %v, but got: %v", tc.origin, tc.expectedValid, gotValid)
			}
			if gotOrigin != tc.expectedOrigin {
				t.Errorf("Expected normalized origin '%s' for origin '%s', but got: '%s'", tc.expectedOrigin, tc.origin, gotOrigin)
			}
		})
	}
}
func Test_normalizeSchemeHost(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
scheme string
host string
expectedHost string
}{
{
name: "http default port added",
scheme: "http",
host: "example.com",
expectedHost: "example.com:80",
},
{
name: "https default port added",
scheme: "https",
host: "example.com",
expectedHost: "example.com:443",
},
{
name: "http custom port preserved",
scheme: "http",
host: "example.com:8080",
expectedHost: "example.com:8080",
},
{
name: "https ipv6 default port added",
scheme: "https",
host: "[::1]",
expectedHost: "[::1]:443",
},
{
name: "unknown scheme preserved",
scheme: "ftp",
host: "example.com",
expectedHost: "example.com",
},
{
name: "https ipv6 custom port preserved",
scheme: "https",
host: "[::1]:8080",
expectedHost: "[::1]:8080",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.expectedHost, normalizeSchemeHost(tc.scheme, tc.host))
})
}
}
// go test -v -run TestSubdomainMatch
//
// TestSubdomainMatch checks subdomain.match against matching origins and
// origins that differ in scheme, prefix or suffix, or that contain empty
// labels. Each subtest has a unique, descriptive name.
func TestSubdomainMatch(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		sub      subdomain
		origin   string
		expected bool
	}{
		{
			name:     "no match with different scheme and subdomain prefix",
			sub:      subdomain{prefix: "http://api.", suffix: "example.com"},
			origin:   "https://api.service.example.com",
			expected: false,
		},
		{
			name:     "no match with different scheme",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "http://api.service.example.com",
			expected: false,
		},
		{
			name:     "match with valid subdomain",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://api.service.example.com",
			expected: true,
		},
		{
			name:     "match with valid nested subdomain",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://1.2.api.service.example.com",
			expected: true,
		},
		{
			name:     "no match with invalid prefix",
			sub:      subdomain{prefix: "https://abc.", suffix: "example.com"},
			origin:   "https://service.example.com",
			expected: false,
		},
		{
			name:     "no match with invalid suffix",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://api.example.org",
			expected: false,
		},
		{
			name:     "no match with empty origin",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "",
			expected: false,
		},
		{
			name:     "no match with malformed subdomain",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://evil.comexample.com",
			expected: false,
		},
		{
			name:     "partial match not considered a match",
			sub:      subdomain{prefix: "https://service.", suffix: "example.com"},
			origin:   "https://api.example.com",
			expected: false,
		},
		{
			name:     "no match with empty host label",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://.example.com",
			expected: false,
		},
		{
			name:     "no match with malformed host label",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://..example.com",
			expected: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := tt.sub.match(tt.origin)
			assert.Equal(t, tt.expected, got, "subdomain.match()")
		})
	}
}
// go test -v -run=^$ -bench=Benchmark_CSRF_SubdomainMatch -benchmem -count=4
//
// Benchmark_CSRF_SubdomainMatch measures subdomain.match on a realistic
// wildcard pattern ("https://" + "*." + "example.com") and a matching origin,
// so every validation step runs rather than an early reject.
func Benchmark_CSRF_SubdomainMatch(b *testing.B) {
	s := subdomain{
		prefix: "https://",
		suffix: "example.com",
	}
	o := "https://www.example.com"

	// Guard against silently benchmarking the rejection path.
	if !s.match(o) {
		b.Fatalf("expected %q to match the benchmark pattern", o)
	}

	b.ReportAllocs()
	for b.Loop() {
		s.match(o)
	}
}
================================================
FILE: middleware/csrf/session_manager.go
================================================
package csrf
import (
	"bytes"
	"strings"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/fiber/v3/log"
	"github.com/gofiber/fiber/v3/middleware/session"
)
// sessionManager stores CSRF tokens in a session.
type sessionManager struct {
	// session is the store used when the request has no session attached by
	// the session middleware. It is nil when newSessionManager got a nil store.
	session *session.Store
}
// sessionKeyType is the unexported type of the session key that holds the CSRF
// Token. A dedicated type presumably keeps the key distinct from keys of other
// types stored in the same session; confirm against the session store's key
// semantics.
type sessionKeyType int

// sessionKey is the session key under which the CSRF Token is stored.
const sessionKey = sessionKeyType(0)
// newSessionManager returns a sessionManager that falls back to store s.
//
// When s is non-nil, it also calls s.RegisterType for sessionKeyType and Token,
// presumably so the store's encoder can handle the key and value this package
// writes. When s is nil, the manager has no store.
func newSessionManager(s *session.Store) *sessionManager {
	manager := &sessionManager{session: s}
	if s == nil {
		return manager
	}

	s.RegisterType(sessionKeyType(0))
	s.RegisterType(Token{})
	return manager
}
// getRaw returns the raw token stored in the session if it is still valid for
// key and equal to raw; otherwise it returns nil.
//
// The session attached by the session middleware is used when present.
// Otherwise the session is loaded from m.session, and a load error yields nil.
// A stored token is treated as absent when any of these hold:
//   - it is missing or not a Token;
//   - it has expired;
//   - it was issued for a different key;
//   - its bytes differ from raw (compared in constant time).
func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte {
	var (
		token Token
		found bool
	)

	if sess := session.FromContext(c); sess != nil {
		token, found = sess.Get(sessionKey).(Token)
	} else {
		storeSess, err := m.session.Get(c)
		if err != nil {
			// No session could be loaded, so there is no token to return.
			return nil
		}
		token, found = storeSess.Get(sessionKey).(Token)
	}

	if !found || token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) {
		return nil
	}
	return token.Raw
}
// setRaw stores a Token for key with value raw in the session, expiring exp
// from now.
//
// When the session middleware attached a session to the request, the token is
// only set on it; persisting it is left to that middleware (not shown here).
// Otherwise the session is loaded from m.session, updated and saved at once.
// A load error drops the write silently; a save error is logged as a warning.
func (m *sessionManager) setRaw(c fiber.Ctx, key string, raw []byte, exp time.Duration) {
	// The key is crucial in CSRF and may reference memory that is reused after
	// the request (pooled buffers / unsafe string and byte conversions). The
	// session therefore keeps its own copies of key and raw.
	token := Token{
		Key:        strings.Clone(key),
		Raw:        bytes.Clone(raw),
		Expiration: time.Now().Add(exp),
	}

	if sess := session.FromContext(c); sess != nil {
		sess.Set(sessionKey, token)
		return
	}

	// Try to get the session from the store
	storeSess, err := m.session.Get(c)
	if err != nil {
		return
	}
	storeSess.Set(sessionKey, token)
	if err := storeSess.Save(); err != nil {
		log.Warn("csrf: failed to save session: ", err)
	}
}
// delete token from session
func (m *sessionManager) delRaw(c fiber.Ctx) {
sess := session.FromContext(c)
if sess != nil {
sess.Delete(sessionKey)
} else {
// Try to get the session from the store
storeSess, err := m.session.Get(c)
if err != nil {
// Handle error
return
}
storeSess.Delete(sessionKey)
if err := storeSess.Save(); err != nil {
log.Warn("csrf: failed to save session: ", err)
}
}
}
================================================
FILE: middleware/csrf/storage_manager.go
================================================
package csrf
import (
"context"
"fmt"
"sync"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/memory"
)
// msgp -file="storage_manager.go" -o="storage_manager_msgp.go" -tests=true -unexported
//
//go:generate msgp -o=storage_manager_msgp.go -tests=true -unexported
// item is an empty placeholder type. msgp generates its encoding methods from
// the go:generate directive above, and newStorageManager's pool hands out
// *item values. NOTE(review): no other use is visible here; confirm before
// removing.
type item struct{}

// redactedKey replaces storage keys in error messages when key redaction is enabled.
const redactedKey = "[redacted]"
//msgp:ignore manager
//msgp:ignore storageManager
// storageManager reads and writes raw CSRF data through either an external
// fiber.Storage or, when none is configured, an in-memory store.
type storageManager struct {
	pool       sync.Pool       `msg:"-"` //nolint:revive // Ignore unexported type
	memory     *memory.Storage `msg:"-"` //nolint:revive // Ignore unexported type // in-memory fallback; nil when storage is set
	storage    fiber.Storage   `msg:"-"` //nolint:revive // Ignore unexported type // external backend; nil when using memory
	redactKeys bool            // when true, keys are replaced by redactedKey in error messages
}
// newStorageManager returns a storageManager backed by storage, or by a new
// in-memory store when storage is nil. redactKeys controls whether storage
// keys are hidden in error messages (see logKey).
func newStorageManager(storage fiber.Storage, redactKeys bool) *storageManager {
	m := &storageManager{
		pool: sync.Pool{
			New: func() any { return new(item) },
		},
		redactKeys: redactKeys,
	}

	if storage == nil {
		m.memory = memory.New()
		return m
	}
	m.storage = storage
	return m
}
// getRaw returns the raw value stored under key.
//
// With an external storage, lookup errors are wrapped and returned. With the
// in-memory fallback:
//   - a missing key yields (nil, nil);
//   - a value that is not a []byte yields an error.
func (m *storageManager) getRaw(ctx context.Context, key string) ([]byte, error) {
	if m.storage == nil {
		value := m.memory.Get(key)
		if value == nil {
			return nil, nil
		}
		raw, ok := value.([]byte)
		if !ok {
			return nil, fmt.Errorf("csrf: unexpected value type %T in storage", value)
		}
		return raw, nil
	}

	raw, err := m.storage.GetWithContext(ctx, key)
	if err != nil {
		return nil, fmt.Errorf("csrf: failed to get value from storage: %w", err)
	}
	return raw, nil
}
// setRaw stores raw under key with time-to-live exp. It writes to the external
// storage when one is configured, otherwise to the in-memory fallback (which
// never fails). Storage errors are wrapped, with the key redacted when
// configured.
func (m *storageManager) setRaw(ctx context.Context, key string, raw []byte, exp time.Duration) error {
	if m.storage == nil {
		m.memory.Set(key, raw, exp)
		return nil
	}

	err := m.storage.SetWithContext(ctx, key, raw, exp)
	if err != nil {
		return fmt.Errorf("csrf: failed to store key %q: %w", m.logKey(key), err)
	}
	return nil
}
// delRaw removes key from the external storage when one is configured,
// otherwise from the in-memory fallback. Storage errors are wrapped, with the
// key redacted when configured.
func (m *storageManager) delRaw(ctx context.Context, key string) error {
	if m.storage == nil {
		m.memory.Delete(key)
		return nil
	}

	err := m.storage.DeleteWithContext(ctx, key)
	if err != nil {
		return fmt.Errorf("csrf: failed to delete key %q: %w", m.logKey(key), err)
	}
	return nil
}
// logKey returns the form of key that may appear in error messages: the key
// itself, or redactedKey when redaction is enabled.
func (m *storageManager) logKey(key string) string {
	if !m.redactKeys {
		return key
	}
	return redactedKey
}
================================================
FILE: middleware/csrf/storage_manager_msgp.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package csrf
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
//
// item has no fields: the map header is read and every entry is skipped.
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, err = dc.ReadMapHeader()
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, err = dc.ReadMapKeyPtr()
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		default:
			err = dc.Skip()
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	return
}
// EncodeMsg implements msgp.Encodable
//
// item has no fields, so only a size-0 map header (byte 0x80) is written.
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
	// map header, size 0
	_ = z
	err = en.Append(0x80)
	if err != nil {
		return
	}
	return
}
// MarshalMsg implements msgp.Marshaler
//
// It appends a size-0 map header (byte 0x80) to b and returns the result.
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
	o = msgp.Require(b, z.Msgsize())
	// map header, size 0
	_ = z
	o = append(o, 0x80)
	return
}
// UnmarshalMsg implements msgp.Unmarshaler
//
// item has no fields: the map header is consumed, every entry is skipped, and
// the remaining bytes are returned.
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, bts, err = msgp.ReadMapKeyZC(bts)
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		default:
			bts, err = msgp.Skip(bts)
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	o = bts
	return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
//
// For item this is the single map-header byte.
func (z item) Msgsize() (s int) {
	s = 1
	return
}
================================================
FILE: middleware/csrf/storage_manager_msgp_test.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package csrf
import (
"bytes"
"testing"
"github.com/tinylib/msgp/msgp"
)
// TestMarshalUnmarshalitem checks that MarshalMsg output is fully consumed by
// both UnmarshalMsg and msgp.Skip, with no bytes left over.
func TestMarshalUnmarshalitem(t *testing.T) {
	v := item{}
	bts, err := v.MarshalMsg(nil)
	if err != nil {
		t.Fatal(err)
	}
	left, err := v.UnmarshalMsg(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
	}
	left, err = msgp.Skip(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
	}
}
// BenchmarkMarshalMsgitem measures MarshalMsg into a fresh (nil) buffer.
func BenchmarkMarshalMsgitem(b *testing.B) {
	v := item{}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.MarshalMsg(nil)
	}
}
// BenchmarkAppendMsgitem measures MarshalMsg appending into a reused buffer.
func BenchmarkAppendMsgitem(b *testing.B) {
	v := item{}
	bts := make([]byte, 0, v.Msgsize())
	bts, _ = v.MarshalMsg(bts[0:0])
	b.SetBytes(int64(len(bts)))
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		bts, _ = v.MarshalMsg(bts[0:0])
	}
}
// BenchmarkUnmarshalitem measures UnmarshalMsg on one pre-marshaled item.
func BenchmarkUnmarshalitem(b *testing.B) {
	v := item{}
	bts, _ := v.MarshalMsg(nil)
	b.ReportAllocs()
	b.SetBytes(int64(len(bts)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := v.UnmarshalMsg(bts)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// TestEncodeDecodeitem round-trips item through the streaming encoder and
// decoder. It also logs (but does not fail) when Msgsize underestimates the
// encoded size.
func TestEncodeDecodeitem(t *testing.T) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	m := v.Msgsize()
	if buf.Len() > m {
		t.Log("WARNING: TestEncodeDecodeitem Msgsize() is inaccurate")
	}
	vn := item{}
	err := msgp.Decode(&buf, &vn)
	if err != nil {
		t.Error(err)
	}
	buf.Reset()
	msgp.Encode(&buf, &v)
	err = msgp.NewReader(&buf).Skip()
	if err != nil {
		t.Error(err)
	}
}
// BenchmarkEncodeitem measures EncodeMsg into a writer that discards output.
func BenchmarkEncodeitem(b *testing.B) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	en := msgp.NewWriter(msgp.Nowhere)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.EncodeMsg(en)
	}
	en.Flush()
}
// BenchmarkDecodeitem measures DecodeMsg reading from an endless reader that
// repeats one encoded item.
func BenchmarkDecodeitem(b *testing.B) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	rd := msgp.NewEndlessReader(buf.Bytes(), b)
	dc := msgp.NewReader(rd)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := v.DecodeMsg(dc)
		if err != nil {
			b.Fatal(err)
		}
	}
}
================================================
FILE: middleware/csrf/token.go
================================================
package csrf
import (
"time"
)
// Token represents a CSRF token with expiration metadata.
// This is used internally for token storage and validation: the session-backed
// manager stores it under sessionKey and registers it with the session store.
type Token struct {
	Expiration time.Time `json:"expiration"` // instant after which the token is no longer accepted
	Key        string    `json:"key"`        // key the token was issued for; must match on lookup
	Raw        []byte    `json:"raw"`        // token bytes, compared in constant time
}
================================================
FILE: middleware/earlydata/config.go
================================================
package earlydata
import (
"github.com/gofiber/fiber/v3"
)
// DefaultHeaderName is the request header inspected by the default
// IsEarlyData function.
const DefaultHeaderName = "Early-Data"

// DefaultHeaderTrueValue is the DefaultHeaderName value that the default
// IsEarlyData function treats as "this request is early data".
const DefaultHeaderTrueValue = "1"
// Config defines the config for the earlydata middleware.
//
// When a Config is passed to New, each nil IsEarlyData, AllowEarlyData or
// Error field is replaced with the ConfigDefault value. Next has no default.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool
	// IsEarlyData returns whether the request is an early-data request.
	//
	// Optional. Default: a function which checks if the "Early-Data" request header equals "1".
	IsEarlyData func(c fiber.Ctx) bool
	// AllowEarlyData returns whether the early-data request should be allowed or rejected.
	// It is only consulted for early-data requests from a trusted proxy.
	//
	// Optional. Default: a function which rejects the request on unsafe and allows the request on safe HTTP request methods.
	AllowEarlyData func(c fiber.Ctx) bool
	// Error is returned if an early-data request is rejected, including when
	// the request does not come from a trusted proxy.
	//
	// Optional. Default: fiber.ErrTooEarly.
	Error error
}
// ConfigDefault is the default config.
//
// It treats a request as early data when its DefaultHeaderName header equals
// DefaultHeaderTrueValue. It allows early data only for methods that
// fiber.IsMethodSafe accepts, and rejects everything else with
// fiber.ErrTooEarly. Next is nil.
var ConfigDefault = Config{
	Error: fiber.ErrTooEarly,
	IsEarlyData: func(c fiber.Ctx) bool {
		headerValue := c.Get(DefaultHeaderName)
		return headerValue == DefaultHeaderTrueValue
	},
	AllowEarlyData: func(c fiber.Ctx) bool {
		method := c.Method()
		return fiber.IsMethodSafe(method)
	},
}
// configDefault resolves the effective middleware config.
//
// With no argument it returns ConfigDefault. Otherwise it takes the first
// Config (any further ones are ignored) and fills its nil IsEarlyData,
// AllowEarlyData and Error fields from ConfigDefault.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}

	cfg := config[0]

	// Fill each optional callback/error that the caller left unset.
	if cfg.IsEarlyData == nil {
		cfg.IsEarlyData = ConfigDefault.IsEarlyData
	}
	if cfg.AllowEarlyData == nil {
		cfg.AllowEarlyData = ConfigDefault.AllowEarlyData
	}
	if cfg.Error == nil {
		cfg.Error = ConfigDefault.Error
	}

	return cfg
}
================================================
FILE: middleware/earlydata/earlydata.go
================================================
package earlydata
import (
"github.com/gofiber/fiber/v3"
)
// contextKey is an unexported type for this package's Locals keys, so they
// cannot collide with context keys defined in other packages.
type contextKey int

// localsKeyAllowed is the Locals key New sets when it accepts an early-data
// request; IsEarly checks for it (earlydata_allowed).
const localsKeyAllowed = contextKey(0)
// IsEarly reports whether the request used early data and the earlydata
// middleware accepted it. It checks for the marker New stores in Locals on
// acceptance.
func IsEarly(c fiber.Ctx) bool {
	marker := c.Locals(localsKeyAllowed)
	return marker != nil
}
// New creates the early-data middleware handler described in
// https://datatracker.ietf.org/doc/html/rfc8470#section-5.1
//
// Requests continue down the stack unchanged in two cases: Next returns true,
// or IsEarlyData reports that the request is not early data.
//
// An early-data request is rejected with cfg.Error in two cases: it does not
// come from a trusted proxy (c.IsProxyTrusted), or AllowEarlyData refuses it.
// Accepted early-data requests are marked so that IsEarly reports true.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)

	return func(c fiber.Ctx) error {
		// Case lists are evaluated left to right and stop at the first true
		// condition, so the callbacks run in the same order as sequential ifs.
		switch {
		case cfg.Next != nil && cfg.Next(c), !cfg.IsEarlyData(c):
			return c.Next()
		case !c.IsProxyTrusted(), !cfg.AllowEarlyData(c):
			// The early-data header can't be trusted, or early data is not
			// permitted for this request.
			return cfg.Error
		}

		_ = c.Locals(localsKeyAllowed, true)
		return c.Next()
	}
}
================================================
FILE: middleware/earlydata/earlydata_test.go
================================================
package earlydata
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// Early-Data header name and the values the tests send.
// Remote addresses for requests the app should and should not trust as a
// proxy (as used by the tests in this file).
const (
	headerName   = "Early-Data"
	headerValOn  = "1"
	headerValOff = "0"

	trustedRemoteAddr   = "0.0.0.0:1234"
	untrustedRemoteAddr = "203.0.113.1:1234"
)
// appWithConfig builds a test app (with *c as config when c is non-nil) that
// has three pieces:
//   - the earlydata middleware;
//   - a checker middleware that verifies IsEarly against the request's
//     Early-Data header and method;
//   - a GET/POST "/" handler that fails unless the checker ran and succeeded.
func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App {
	t.Helper()

	var configs []fiber.Config
	if c != nil {
		configs = append(configs, *c)
	}
	app := fiber.New(configs...)

	app.Use(New())

	// Middleware to test IsEarly func
	const localsKeyTestValid = "earlydata_testvalid"
	app.Use(func(c fiber.Ctx) error {
		early := IsEarly(c)
		header := c.Get(headerName)

		switch {
		case header == "" || header == headerValOff:
			if early {
				return errors.New("is early-data even though it's not")
			}
		case header == headerValOn && fiber.IsMethodSafe(c.Method()):
			if !early {
				return errors.New("should be early-data on safe HTTP methods")
			}
		case header == headerValOn:
			if early {
				return errors.New("early-data unsupported on unsafe HTTP methods")
			}
		default:
			return fmt.Errorf("header has unsupported value: %s", header)
		}

		_ = c.Locals(localsKeyTestValid, true)
		return c.Next()
	})

	methods := []string{fiber.MethodGet, fiber.MethodPost}
	app.Add(methods, "/", func(c fiber.Ctx) error {
		valid, ok := c.Locals(localsKeyTestValid).(bool)
		if !ok {
			panic(errors.New("failed to type-assert to bool"))
		}
		if !valid {
			return errors.New("handler called even though validation failed")
		}
		return nil
	})

	return app
}
// requestExpectation describes a single request issued by
// executeExpectations and the status code it must produce.
type requestExpectation struct {
	// method is the HTTP method of the request.
	method string
	// header is the Early-Data header value; "" leaves the header unset.
	header string
	// status is the expected response status code.
	status int
}
// executeExpectations sends one request to "/" per expectation, using
// remoteAddr as the client address, and asserts the response status.
// Failure messages name the method, header value and remote address so the
// failing case can be identified without bisecting the table.
func executeExpectations(t *testing.T, app *fiber.App, remoteAddr string, expectations []requestExpectation) {
	t.Helper()

	for _, expectation := range expectations {
		req := httptest.NewRequest(expectation.method, "/", http.NoBody)
		req.RemoteAddr = remoteAddr
		if expectation.header != "" {
			req.Header.Set(headerName, expectation.header)
		}

		resp, err := app.Test(req)
		require.NoError(t, err,
			"method=%s %s=%q remote=%s", expectation.method, headerName, expectation.header, remoteAddr)
		require.Equal(t, expectation.status, resp.StatusCode,
			"method=%s %s=%q remote=%s", expectation.method, headerName, expectation.header, remoteAddr)
	}
}
// go test -run Test_EarlyData
func Test_EarlyData(t *testing.T) {
	t.Parallel()

	// From an untrusted peer, every request with Early-Data: 1 is rejected.
	untrustedExpectations := []requestExpectation{
		{method: fiber.MethodGet, status: fiber.StatusOK},
		{method: fiber.MethodGet, header: headerValOff, status: fiber.StatusOK},
		{method: fiber.MethodGet, header: headerValOn, status: fiber.StatusTooEarly},
		{method: fiber.MethodPost, status: fiber.StatusOK},
		{method: fiber.MethodPost, header: headerValOff, status: fiber.StatusOK},
		{method: fiber.MethodPost, header: headerValOn, status: fiber.StatusTooEarly},
	}

	// Through a trusted proxy, early data is accepted for GET but not POST.
	trustedExpectations := []requestExpectation{
		{method: fiber.MethodGet, status: fiber.StatusOK},
		{method: fiber.MethodGet, header: headerValOff, status: fiber.StatusOK},
		{method: fiber.MethodGet, header: headerValOn, status: fiber.StatusOK},
		{method: fiber.MethodPost, status: fiber.StatusOK},
		{method: fiber.MethodPost, header: headerValOff, status: fiber.StatusOK},
		{method: fiber.MethodPost, header: headerValOn, status: fiber.StatusTooEarly},
	}

	cases := []struct {
		name         string
		config       *fiber.Config
		remoteAddr   string
		expectations []requestExpectation
	}{
		{
			name:         "empty config",
			config:       nil,
			remoteAddr:   untrustedRemoteAddr,
			expectations: untrustedExpectations,
		},
		{
			name:         "default config",
			config:       &fiber.Config{},
			remoteAddr:   untrustedRemoteAddr,
			expectations: untrustedExpectations,
		},
		{
			name:         "config with TrustProxy and untrusted remote",
			config:       &fiber.Config{TrustProxy: true},
			remoteAddr:   untrustedRemoteAddr,
			expectations: untrustedExpectations,
		},
		{
			name: "config with TrustProxy and trusted TrustProxyConfig.Proxies",
			config: &fiber.Config{
				TrustProxy: true,
				TrustProxyConfig: fiber.TrustProxyConfig{
					Proxies: []string{"0.0.0.0"},
				},
			},
			remoteAddr:   trustedRemoteAddr,
			expectations: trustedExpectations,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := appWithConfig(t, tc.config)
			executeExpectations(t, app, tc.remoteAddr, tc.expectations)
		})
	}
}
// Test_EarlyDataNext verifies that the middleware skips its logic when Next returns true.
func Test_EarlyDataNext(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(fiber.Ctx) bool { return true },
}))
called := false
app.Get("/", func(c fiber.Ctx) error {
called = true
if IsEarly(c) {
return errors.New("IsEarly(c) should be false when Next returns true")
}
return nil
})
req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
req.Header.Set(headerName, headerValOn)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.True(t, called)
}
// Test_configDefault_NoConfig checks that configDefault with no arguments
// yields ConfigDefault unchanged: the same error value and the very same
// IsEarlyData/AllowEarlyData functions.
func Test_configDefault_NoConfig(t *testing.T) {
	t.Parallel()

	// Functions are not comparable in Go, so compare their code pointers.
	funcAddr := func(fn any) uintptr {
		return reflect.ValueOf(fn).Pointer()
	}

	cfg := configDefault()

	require.Equal(t, ConfigDefault.Error, cfg.Error)
	require.Equal(t, funcAddr(ConfigDefault.IsEarlyData), funcAddr(cfg.IsEarlyData))
	require.Equal(t, funcAddr(ConfigDefault.AllowEarlyData), funcAddr(cfg.AllowEarlyData))
}
// Test_configDefault_WithConfig verifies that provided configuration fields are
// kept while missing fields are populated with defaults.
func Test_configDefault_WithConfig(t *testing.T) {
	t.Parallel()

	expectedErr := errors.New("boom")
	called := false
	custom := Config{
		Next: func(_ fiber.Ctx) bool {
			called = true
			return false
		},
		Error: expectedErr,
	}

	cfg := configDefault(custom)

	// Next should be preserved and not invoked by configDefault.
	require.False(t, called)
	require.Equal(t, reflect.ValueOf(custom.Next).Pointer(), reflect.ValueOf(cfg.Next).Pointer())

	// Custom error must be preserved.
	require.Equal(t, expectedErr, cfg.Error)

	// Missing fields should be set to defaults.
	require.NotNil(t, cfg.IsEarlyData)
	require.NotNil(t, cfg.AllowEarlyData)

	// Verify the default functions on a GET request carrying the default
	// early-data header. The deferred release returns the context to the app
	// even when one of the assertions below fails.
	app := fiber.New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	c.Request().Header.Set(DefaultHeaderName, DefaultHeaderTrueValue)
	c.Request().Header.SetMethod(fiber.MethodGet)
	require.True(t, cfg.IsEarlyData(c))
	require.True(t, cfg.AllowEarlyData(c))
}
================================================
FILE: middleware/encryptcookie/config.go
================================================
package encryptcookie
import (
"github.com/gofiber/fiber/v3"
)
// Config defines the config for middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// Custom function to encrypt cookies. The middleware calls it for every
	// response cookie not listed in Except, passing the cookie name, its
	// plaintext value and Key; the returned string replaces the value.
	//
	// Optional. Default: EncryptCookie (using AES-GCM)
	Encryptor func(name, decryptedString, key string) (string, error)

	// Custom function to decrypt cookies. The middleware calls it for every
	// request cookie not listed in Except; a cookie whose decryption returns
	// an error is removed from the request.
	//
	// Optional. Default: DecryptCookie (using AES-GCM)
	Decryptor func(name, encryptedString, key string) (string, error)

	// Base64 encoded unique key to encode & decode cookies.
	//
	// Required. Key length should be 16, 24, or 32 bytes when decoded
	// if using the default EncryptCookie and DecryptCookie functions.
	// Note that the middleware validates this (and panics on an empty or
	// invalid key) even when custom Encryptor/Decryptor functions are set.
	// You may use `encryptcookie.GenerateKey(length)` to generate a new key.
	Key string

	// Array of cookie names that are neither decrypted on requests nor
	// encrypted on responses.
	//
	// Optional. Default: []
	Except []string
}
// ConfigDefault is the default config. It deliberately carries no Key, so it
// cannot be used on its own: constructing the middleware without supplying a
// Key panics.
var ConfigDefault = Config{
	Next:      nil,
	Encryptor: EncryptCookie,
	Decryptor: DecryptCookie,
	Key:       "",
	Except:    []string{},
}
// configDefault merges the optional user config with ConfigDefault and
// validates the resulting key.
//
// Only the first element of config is considered; its nil Next, Encryptor,
// Decryptor and Except fields are filled from ConfigDefault. It panics with
// a message when Key is empty, and with the validateKey error when Key is not
// a base64-encoded 16, 24 or 32 byte key.
func configDefault(config ...Config) Config {
	cfg := ConfigDefault

	if len(config) != 0 {
		cfg = config[0]

		// Fill in whatever the caller left unset.
		if cfg.Next == nil {
			cfg.Next = ConfigDefault.Next
		}
		if cfg.Except == nil {
			cfg.Except = ConfigDefault.Except
		}
		if cfg.Encryptor == nil {
			cfg.Encryptor = ConfigDefault.Encryptor
		}
		if cfg.Decryptor == nil {
			cfg.Decryptor = ConfigDefault.Decryptor
		}
	}

	// An empty key is reported separately from a malformed one.
	switch err := validateKey(cfg.Key); {
	case cfg.Key == "":
		panic("fiber: encrypt cookie middleware requires key")
	case err != nil:
		panic(err)
	}

	return cfg
}
================================================
FILE: middleware/encryptcookie/config_test.go
================================================
package encryptcookie
import (
"crypto/rand"
"encoding/base64"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func Test_configDefault_KeyValidation(t *testing.T) {
t.Parallel()
t.Run("invalid base64", func(t *testing.T) {
t.Parallel()
_, decErr := base64.StdEncoding.DecodeString("invalid")
expectedErr := fmt.Errorf("failed to base64-decode key: %w", decErr).Error()
require.PanicsWithError(t, expectedErr, func() {
configDefault(Config{Key: "invalid"})
})
})
t.Run("invalid length", func(t *testing.T) {
t.Parallel()
key := make([]byte, 20)
_, err := rand.Read(key)
require.NoError(t, err)
invalidKey := base64.StdEncoding.EncodeToString(key)
require.PanicsWithValue(t, ErrInvalidKeyLength, func() {
configDefault(Config{Key: invalidKey})
})
})
}
================================================
FILE: middleware/encryptcookie/encryptcookie.go
================================================
package encryptcookie
import (
"errors"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v3"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Decrypt request cookies
cookiesToDelete := make([][]byte, 0, 4)
for key, value := range c.Request().Header.Cookies() {
keyString := string(key)
if !isDisabled(keyString, cfg.Except) {
decryptedValue, err := cfg.Decryptor(keyString, string(value), cfg.Key)
if err != nil {
cookiesToDelete = append(cookiesToDelete, key)
} else {
c.Request().Header.SetCookie(keyString, decryptedValue)
}
}
}
// Delete cookies that failed to decrypt - outside the loop to avoid mutation during iteration
for _, key := range cookiesToDelete {
c.Request().Header.DelCookieBytes(key)
}
// Continue stack
err := c.Next()
// Encrypt response cookies
for key := range c.Response().Header.Cookies() {
keyString := string(key)
if !isDisabled(keyString, cfg.Except) {
cookieValue := fasthttp.Cookie{}
cookieValue.SetKeyBytes(key)
if c.Response().Header.Cookie(&cookieValue) {
encryptedValue, encErr := cfg.Encryptor(keyString, string(cookieValue.Value()), cfg.Key)
if encErr != nil {
return errors.Join(err, encErr)
}
cookieValue.SetValue(encryptedValue)
c.Response().Header.SetCookie(&cookieValue)
}
}
}
return err
}
}
================================================
FILE: middleware/encryptcookie/encryptcookie_test.go
================================================
package encryptcookie
import (
"crypto/rand"
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v3"
)
// Test_Middleware_Panics checks that registering the middleware without a key
// panics, and that GenerateKey rejects an unsupported key length.
func Test_Middleware_Panics(t *testing.T) {
	t.Parallel()

	t.Run("Empty Key", func(t *testing.T) {
		t.Parallel()

		app := fiber.New()
		register := func() {
			app.Use(New(Config{Key: ""}))
		}
		require.Panics(t, register)
	})

	t.Run("Invalid Key", func(t *testing.T) {
		t.Parallel()

		require.Panics(t, func() { GenerateKey(11) })
	})
}
// Test_Middleware_InvalidKeys checks that EncryptCookie and DecryptCookie
// both reject well-formed base64 keys whose decoded size is not a valid AES
// key length.
func Test_Middleware_InvalidKeys(t *testing.T) {
	t.Parallel()

	// randomKey returns a base64-encoded random key of n raw bytes.
	randomKey := func(t *testing.T, n int) string {
		t.Helper()
		raw := make([]byte, n)
		_, err := rand.Read(raw)
		require.NoError(t, err)
		return base64.StdEncoding.EncodeToString(raw)
	}

	for _, n := range []int{11, 25, 60} {
		prefix := strconv.Itoa(n)

		t.Run(prefix+"_length_encrypt", func(t *testing.T) {
			t.Parallel()
			_, err := EncryptCookie("test", "SomeThing", randomKey(t, n))
			require.Error(t, err)
		})

		t.Run(prefix+"_length_decrypt", func(t *testing.T) {
			t.Parallel()
			_, err := DecryptCookie("test", "SomeThing", randomKey(t, n))
			require.Error(t, err)
		})
	}
}
func Test_Middleware_InvalidBase64(t *testing.T) {
t.Parallel()
invalidBase64 := "invalid-base64-string-!@#"
t.Run("encryptor", func(t *testing.T) {
t.Parallel()
_, err := EncryptCookie("test", "SomeText", invalidBase64)
require.Error(t, err)
require.ErrorContains(t, err, "failed to base64-decode key")
})
t.Run("decryptor_key", func(t *testing.T) {
t.Parallel()
_, err := DecryptCookie("test", "SomeText", invalidBase64)
require.Error(t, err)
require.ErrorContains(t, err, "failed to base64-decode key")
})
t.Run("decryptor_value", func(t *testing.T) {
t.Parallel()
_, err := DecryptCookie("test", invalidBase64, GenerateKey(32))
require.Error(t, err)
require.ErrorContains(t, err, "failed to base64-decode value")
})
}
// Test_DecryptCookie_InvalidEncryptedValue checks that a value too short to
// hold the GCM nonce and tag is rejected with ErrInvalidEncryptedValue.
func Test_DecryptCookie_InvalidEncryptedValue(t *testing.T) {
	t.Parallel()

	key := GenerateKey(32)
	// "short" decodes to 5 bytes, fewer than the nonce plus tag overhead.
	tooShort := base64.StdEncoding.EncodeToString([]byte("short"))

	_, err := DecryptCookie("session", tooShort, key)
	require.ErrorIs(t, err, ErrInvalidEncryptedValue)
}
// Test_Middleware_EncryptionErrorPropagates checks that an Encryptor error is
// forwarded to the app's ErrorHandler.
func Test_Middleware_EncryptionErrorPropagates(t *testing.T) {
	t.Parallel()

	wantErr := errors.New("encrypt failed")
	var gotErr error

	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			gotErr = err
			return c.Status(fiber.StatusTeapot).SendString("encryption error")
		},
	})

	// Fails only for the "test" cookie and passes everything else through.
	failingEncryptor := func(name, value, _ string) (string, error) {
		if name != "test" {
			return value, nil
		}
		return "", wantErr
	}
	app.Use(New(Config{
		Key:       GenerateKey(32),
		Encryptor: failingEncryptor,
	}))

	app.Get("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test", Value: "value"})
		return nil
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
	require.ErrorIs(t, gotErr, wantErr)
}
// Test_Middleware_EncryptionErrorDoesNotMaskNextError checks that when both the
// downstream handler and the Encryptor fail, the error reaching the
// ErrorHandler wraps both.
func Test_Middleware_EncryptionErrorDoesNotMaskNextError(t *testing.T) {
	t.Parallel()

	var (
		encryptErr    = errors.New("encrypt failed")
		downstreamErr = errors.New("downstream failed")
		gotErr        error
	)

	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			gotErr = err
			return c.Status(fiber.StatusTeapot).SendString("combined error")
		},
	})

	app.Use(New(Config{
		Key: GenerateKey(32),
		Encryptor: func(name, value, _ string) (string, error) {
			if name != "test" {
				return value, nil
			}
			return "", encryptErr
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test", Value: "value"})
		return downstreamErr
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
	require.ErrorIs(t, gotErr, downstreamErr)
	require.ErrorIs(t, gotErr, encryptErr)
}
// Test_Middleware_Encrypt_Cookie exercises the full round trip through the
// middleware: missing and undecryptable request cookies are hidden from the
// handler, a cookie set by a handler is encrypted in the response, and that
// encrypted value is decrypted again on a following request.
func Test_Middleware_Encrypt_Cookie(t *testing.T) {
	t.Parallel()
	testKey := GenerateKey(32)
	app := fiber.New()

	app.Use(New(Config{
		Key: testKey,
	}))

	// GET echoes the (decrypted) "test" cookie; POST sets it.
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("value=" + c.Cookies("test"))
	})
	app.Post("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{
			Name:  "test",
			Value: "SomeThing",
		})
		return nil
	})

	h := app.Handler()

	// Test empty cookie
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "value=", string(ctx.Response.Body()))

	// Test invalid cookie
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("test", "Invalid")
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "value=", string(ctx.Response.Body()))
	// Reuse the same context with a fixed base64-looking value. testKey is
	// random per run, so this value cannot decrypt and must be dropped too.
	ctx.Request.Header.SetCookie("test", "ixQURE2XOyZUs0WAOh2ehjWcP7oZb07JvnhWOsmeNUhPsj4+RyI=")
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "value=", string(ctx.Response.Body()))

	// Test valid cookie
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// The response cookie must be encrypted under testKey.
	encryptedCookie := fasthttp.Cookie{}
	encryptedCookie.SetKey("test")
	require.True(t, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
	decryptedCookieValue, err := DecryptCookie("test", string(encryptedCookie.Value()), testKey)
	require.NoError(t, err)
	require.Equal(t, "SomeThing", decryptedCookieValue)

	// Sending the encrypted value back must expose the plaintext to the handler.
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "value=SomeThing", string(ctx.Response.Body()))
}
// Test_EncryptCookie_Rejects_Swapped_Names checks that the cookie name is
// bound to the ciphertext: a value encrypted for one name cannot be
// decrypted under another.
func Test_EncryptCookie_Rejects_Swapped_Names(t *testing.T) {
	t.Parallel()

	key := GenerateKey(32)

	sealed, err := EncryptCookie("cookieA", "ValueA", key)
	require.NoError(t, err)

	// The value round-trips under its own name...
	opened, err := DecryptCookie("cookieA", sealed, key)
	require.NoError(t, err)
	require.Equal(t, "ValueA", opened)

	// ...but is rejected when presented under a different name.
	_, err = DecryptCookie("cookieB", sealed, key)
	require.Error(t, err)
	require.ErrorContains(t, err, "failed to decrypt ciphertext")
}
// Test_Encrypt_Cookie_Next checks that when Config.Next returns true the
// response cookie is left in plaintext.
func Test_Encrypt_Cookie_Next(t *testing.T) {
	t.Parallel()
	testKey := GenerateKey(32)
	app := fiber.New()

	app.Use(New(Config{
		Key: testKey,
		Next: func(_ fiber.Ctx) bool {
			return true
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{
			Name:  "test",
			Value: "SomeThing",
		})
		return nil
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)

	// Check the count before indexing so a missing cookie fails the test
	// cleanly instead of panicking with an index-out-of-range error.
	cookies := resp.Cookies()
	require.Len(t, cookies, 1)
	require.Equal(t, "SomeThing", cookies[0].Value)
}
// Test_Encrypt_Cookie_Except checks that cookies listed in Except are sent
// in plaintext while other cookies are encrypted.
func Test_Encrypt_Cookie_Except(t *testing.T) {
	t.Parallel()

	testKey := GenerateKey(32)

	app := fiber.New()
	app.Use(New(Config{
		Key:    testKey,
		Except: []string{"test1"},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		for _, name := range []string{"test1", "test2"} {
			c.Cookie(&fiber.Cookie{Name: name, Value: "SomeThing"})
		}
		return nil
	})

	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	app.Handler()(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// test1 is excluded and must stay in plaintext.
	plain := fasthttp.Cookie{}
	plain.SetKey("test1")
	require.True(t, ctx.Response.Header.Cookie(&plain), "Get cookie value")
	require.Equal(t, "SomeThing", string(plain.Value()))

	// test2 is not excluded and must decrypt back to the original value.
	sealed := fasthttp.Cookie{}
	sealed.SetKey("test2")
	require.True(t, ctx.Response.Header.Cookie(&sealed), "Get cookie value")
	opened, err := DecryptCookie("test2", string(sealed.Value()), testKey)
	require.NoError(t, err)
	require.Equal(t, "SomeThing", opened)
}
// Test_Encrypt_Cookie_Custom_Encryptor checks that user-supplied Encryptor and
// Decryptor functions replace the defaults (here: plain base64 encoding).
func Test_Encrypt_Cookie_Custom_Encryptor(t *testing.T) {
	t.Parallel()

	encodeOnly := func(_, decryptedString, _ string) (string, error) {
		return base64.StdEncoding.EncodeToString([]byte(decryptedString)), nil
	}
	decodeOnly := func(_, encryptedString, _ string) (string, error) {
		decodedBytes, err := base64.StdEncoding.DecodeString(encryptedString)
		return string(decodedBytes), err
	}

	app := fiber.New()
	app.Use(New(Config{
		Key:       GenerateKey(32),
		Encryptor: encodeOnly,
		Decryptor: decodeOnly,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("value=" + c.Cookies("test"))
	})
	app.Post("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test", Value: "SomeThing"})
		return nil
	})
	handler := app.Handler()

	// POST: the cookie must come back base64-encoded by the custom Encryptor.
	postCtx := &fasthttp.RequestCtx{}
	postCtx.Request.Header.SetMethod(fiber.MethodPost)
	handler(postCtx)
	require.Equal(t, 200, postCtx.Response.StatusCode())

	encoded := fasthttp.Cookie{}
	encoded.SetKey("test")
	require.True(t, postCtx.Response.Header.Cookie(&encoded), "Get cookie value")
	raw, err := base64.StdEncoding.DecodeString(string(encoded.Value()))
	require.NoError(t, err)
	require.Equal(t, "SomeThing", string(raw))

	// GET: the custom Decryptor must restore the plaintext for the handler.
	getCtx := &fasthttp.RequestCtx{}
	getCtx.Request.Header.SetMethod(fiber.MethodGet)
	getCtx.Request.Header.SetCookie("test", string(encoded.Value()))
	handler(getCtx)
	require.Equal(t, 200, getCtx.Response.StatusCode())
	require.Equal(t, "value=SomeThing", string(getCtx.Response.Body()))
}
// Test_GenerateKey checks that GenerateKey returns base64 keys of the
// requested raw length and panics for unsupported lengths.
func Test_GenerateKey(t *testing.T) {
	t.Parallel()

	tests := []struct {
		length int
	}{
		{length: 16},
		{length: 24},
		{length: 32},
	}

	decodeBase64 := func(t *testing.T, s string) []byte {
		t.Helper()
		data, err := base64.StdEncoding.DecodeString(s)
		require.NoError(t, err)
		return data
	}

	for _, tt := range tests {
		t.Run(strconv.Itoa(tt.length)+"_length", func(t *testing.T) {
			t.Parallel()
			key := GenerateKey(tt.length)
			decodedKey := decodeBase64(t, key)
			require.Len(t, decodedKey, tt.length)
		})
	}

	t.Run("Invalid Length", func(t *testing.T) {
		// Parallel like every other subtest in this file (paralleltest/tparallel).
		t.Parallel()
		require.Panics(t, func() { GenerateKey(10) })
		require.Panics(t, func() { GenerateKey(20) })
	})
}
// Benchmark_Middleware_Encrypt_Cookie measures requests without a cookie,
// with an undecryptable cookie, and with a cookie set (and encrypted) by the
// handler.
func Benchmark_Middleware_Encrypt_Cookie(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{Key: GenerateKey(32)}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("value=" + c.Cookies("test"))
	})
	app.Post("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test", Value: "SomeThing"})
		return nil
	})
	handler := app.Handler()

	b.Run("Empty Cookie", func(b *testing.B) {
		for b.Loop() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			handler(ctx)
		}
	})

	b.Run("Invalid Cookie", func(b *testing.B) {
		for b.Loop() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			ctx.Request.Header.SetCookie("test", "Invalid")
			handler(ctx)
		}
	})

	b.Run("Valid Cookie", func(b *testing.B) {
		for b.Loop() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodPost)
			handler(ctx)
		}
	})
}
// Benchmark_Encrypt_Cookie_Next measures the overhead of the middleware when
// Config.Next always skips it.
func Benchmark_Encrypt_Cookie_Next(b *testing.B) {
	alwaysSkip := func(_ fiber.Ctx) bool { return true }

	app := fiber.New()
	app.Use(New(Config{
		Key:  GenerateKey(32),
		Next: alwaysSkip,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test", Value: "SomeThing"})
		return nil
	})
	handler := app.Handler()

	b.Run("Encrypt Cookie Next", func(b *testing.B) {
		for b.Loop() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			ctx.Request.SetRequestURI("/")
			handler(ctx)
		}
	})
}
// Benchmark_Encrypt_Cookie_Except measures a response with one excluded and
// one encrypted cookie.
func Benchmark_Encrypt_Cookie_Except(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{
		Key:    GenerateKey(32),
		Except: []string{"test1"},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test1", Value: "SomeThing"})
		c.Cookie(&fiber.Cookie{Name: "test2", Value: "SomeThing"})
		return nil
	})
	handler := app.Handler()

	b.Run("Encrypt Cookie Except", func(b *testing.B) {
		for b.Loop() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			handler(ctx)
		}
	})
}
// Benchmark_Encrypt_Cookie_Custom_Encryptor measures the middleware with
// base64-only Encryptor/Decryptor functions, both when setting a cookie (POST)
// and when reading one back (GET).
func Benchmark_Encrypt_Cookie_Custom_Encryptor(b *testing.B) {
	encodeOnly := func(_, decryptedString, _ string) (string, error) {
		return base64.StdEncoding.EncodeToString([]byte(decryptedString)), nil
	}
	decodeOnly := func(_, encryptedString, _ string) (string, error) {
		decodedBytes, err := base64.StdEncoding.DecodeString(encryptedString)
		return string(decodedBytes), err
	}

	app := fiber.New()
	app.Use(New(Config{
		Key:       GenerateKey(32),
		Encryptor: encodeOnly,
		Decryptor: decodeOnly,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("value=" + c.Cookies("test"))
	})
	app.Post("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test", Value: "SomeThing"})
		return nil
	})
	handler := app.Handler()

	b.Run("Custom Encryptor Post", func(b *testing.B) {
		for b.Loop() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodPost)
			handler(ctx)
		}
	})

	b.Run("Custom Encryptor Get", func(b *testing.B) {
		// Produce an encoded cookie once, outside the measured loop.
		setup := &fasthttp.RequestCtx{}
		setup.Request.Header.SetMethod(fiber.MethodPost)
		handler(setup)

		encoded := fasthttp.Cookie{}
		encoded.SetKey("test")
		require.True(b, setup.Response.Header.Cookie(&encoded), "Get cookie value")

		for b.Loop() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			ctx.Request.Header.SetCookie("test", string(encoded.Value()))
			handler(ctx)
		}
	})
}
// Benchmark_Middleware_Encrypt_Cookie_Parallel is the RunParallel variant of
// Benchmark_Middleware_Encrypt_Cookie.
func Benchmark_Middleware_Encrypt_Cookie_Parallel(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{Key: GenerateKey(32)}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("value=" + c.Cookies("test"))
	})
	app.Post("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test", Value: "SomeThing"})
		return nil
	})
	handler := app.Handler()

	cases := []struct {
		name   string
		method string
		cookie string // raw "test" request cookie; "" sends none
	}{
		{name: "Empty Cookie Parallel", method: fiber.MethodGet},
		{name: "Invalid Cookie Parallel", method: fiber.MethodGet, cookie: "Invalid"},
		{name: "Valid Cookie Parallel", method: fiber.MethodPost},
	}

	for _, bc := range cases {
		b.Run(bc.name, func(b *testing.B) {
			b.ReportAllocs()
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				for pb.Next() {
					ctx := &fasthttp.RequestCtx{}
					ctx.Request.Header.SetMethod(bc.method)
					if bc.cookie != "" {
						ctx.Request.Header.SetCookie("test", bc.cookie)
					}
					handler(ctx)
				}
			})
		})
	}
}
// Benchmark_Encrypt_Cookie_Next_Parallel is the RunParallel variant of
// Benchmark_Encrypt_Cookie_Next.
func Benchmark_Encrypt_Cookie_Next_Parallel(b *testing.B) {
	alwaysSkip := func(_ fiber.Ctx) bool { return true }

	app := fiber.New()
	app.Use(New(Config{
		Key:  GenerateKey(32),
		Next: alwaysSkip,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test", Value: "SomeThing"})
		return nil
	})
	handler := app.Handler()

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			ctx.Request.SetRequestURI("/")
			handler(ctx)
		}
	})
}
// Benchmark_Encrypt_Cookie_Except_Parallel is the RunParallel variant of
// Benchmark_Encrypt_Cookie_Except.
func Benchmark_Encrypt_Cookie_Except_Parallel(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{
		Key:    GenerateKey(32),
		Except: []string{"test1"},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{Name: "test1", Value: "SomeThing"})
		c.Cookie(&fiber.Cookie{Name: "test2", Value: "SomeThing"})
		return nil
	})
	handler := app.Handler()

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			handler(ctx)
		}
	})
}
// Benchmark_Encrypt_Cookie_Custom_Encryptor_Parallel measures concurrent GET
// requests that carry a cookie encoded by a base64-only custom Encryptor.
//
// The encoded cookie is produced once on the benchmark goroutine before
// RunParallel. RunParallel bodies must not call ResetTimer (its effect is
// global), and require/FailNow must not be called from the worker goroutines
// that RunParallel starts.
func Benchmark_Encrypt_Cookie_Custom_Encryptor_Parallel(b *testing.B) {
	testKey := GenerateKey(32)
	app := fiber.New()

	app.Use(New(Config{
		Key: testKey,
		Encryptor: func(_, decryptedString, _ string) (string, error) {
			return base64.StdEncoding.EncodeToString([]byte(decryptedString)), nil
		},
		Decryptor: func(_, encryptedString, _ string) (string, error) {
			decodedBytes, err := base64.StdEncoding.DecodeString(encryptedString)
			return string(decodedBytes), err
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("value=" + c.Cookies("test"))
	})
	app.Post("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{
			Name:  "test",
			Value: "SomeThing",
		})
		return nil
	})

	h := app.Handler()

	// One-time setup on the benchmark goroutine.
	setupCtx := &fasthttp.RequestCtx{}
	setupCtx.Request.Header.SetMethod(fiber.MethodPost)
	h(setupCtx)

	encryptedCookie := fasthttp.Cookie{}
	encryptedCookie.SetKey("test")
	require.True(b, setupCtx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
	// Convert once so the per-iteration []byte->string copy does not inflate
	// the reported allocations.
	cookieValue := string(encryptedCookie.Value())

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			ctx.Request.Header.SetCookie("test", cookieValue)
			h(ctx)
		}
	})
}
// Benchmark_GenerateKey measures key generation for each supported AES key
// length.
func Benchmark_GenerateKey(b *testing.B) {
	for _, length := range []int{16, 24, 32} {
		b.Run(strconv.Itoa(length), func(b *testing.B) {
			for b.Loop() {
				GenerateKey(length)
			}
		})
	}
}
// Benchmark_GenerateKey_Parallel measures concurrent key generation for each
// supported AES key length.
func Benchmark_GenerateKey_Parallel(b *testing.B) {
	for _, length := range []int{16, 24, 32} {
		b.Run(strconv.Itoa(length), func(b *testing.B) {
			b.ReportAllocs()
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				for pb.Next() {
					GenerateKey(length)
				}
			})
		})
	}
}
// Test_Middleware_Mixed_Valid_Invalid_Cookies checks that one undecryptable
// cookie is removed without disturbing the valid cookies around it.
func Test_Middleware_Mixed_Valid_Invalid_Cookies(t *testing.T) {
	t.Parallel()

	testKey := GenerateKey(32)

	// Encrypt the two legitimate cookies up front.
	sealed1, err := EncryptCookie("valid1", "value1", testKey)
	require.NoError(t, err)
	sealed2, err := EncryptCookie("valid2", "value2", testKey)
	require.NoError(t, err)

	app := fiber.New()
	app.Use(New(Config{Key: testKey}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("valid1=" + c.Cookies("valid1") +
			",valid2=" + c.Cookies("valid2") +
			",invalid=" + c.Cookies("invalid"))
	})

	// The invalid cookie sits between the two valid ones.
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("valid1", sealed1)
	ctx.Request.Header.SetCookie("invalid", "thisisnotvalid")
	ctx.Request.Header.SetCookie("valid2", sealed2)

	app.Handler()(ctx)

	require.Equal(t, 200, ctx.Response.StatusCode())
	require.Equal(t, "valid1=value1,valid2=value2,invalid=", string(ctx.Response.Body()))

	// Only the invalid cookie must have been removed from the request.
	require.NotEmpty(t, ctx.Request.Header.Cookie("valid1"))
	require.Empty(t, ctx.Request.Header.Cookie("invalid"))
	require.NotEmpty(t, ctx.Request.Header.Cookie("valid2"))
}
================================================
FILE: middleware/encryptcookie/utils.go
================================================
package encryptcookie
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"slices"
)
var (
	// ErrInvalidKeyLength reports a key whose decoded size is not 16, 24, or
	// 32 bytes. It is returned by key decoding and used as the panic value by
	// GenerateKey for unsupported lengths.
	ErrInvalidKeyLength = errors.New("encryption key must be 16, 24, or 32 bytes")
	// ErrInvalidEncryptedValue is returned by DecryptCookie when the decoded
	// value is shorter than the AEAD's nonce size plus overhead.
	ErrInvalidEncryptedValue = errors.New("encrypted value is not valid")
)
// decodeKey base64-decodes key (standard encoding) and checks that the raw
// bytes form a valid AES key. It returns the raw key on success, a wrapped
// decoding error for malformed base64, or ErrInvalidKeyLength when the
// decoded size is not 16, 24, or 32 bytes.
func decodeKey(key string) ([]byte, error) {
	raw, err := base64.StdEncoding.DecodeString(key)
	if err != nil {
		return nil, fmt.Errorf("failed to base64-decode key: %w", err)
	}

	switch len(raw) {
	case 16, 24, 32:
		return raw, nil
	default:
		return nil, ErrInvalidKeyLength
	}
}
// validateKey reports whether key is a usable base64-encoded AES key. It
// returns nil for a valid key and the decodeKey error otherwise.
func validateKey(key string) error {
	if _, err := decodeKey(key); err != nil {
		return err
	}
	return nil
}
// EncryptCookie encrypts value with AES-GCM under the base64-encoded key and
// returns the sealed bytes encoded with standard base64.
//
// A fresh random nonce is generated for every call and stored alongside the
// ciphertext. The cookie name is bound as additional authenticated data, so
// the result only decrypts when presented under the same name. Errors are
// returned for an invalid key or if the cipher cannot be constructed.
func EncryptCookie(name, value, key string) (string, error) {
	rawKey, err := decodeKey(key)
	if err != nil {
		return "", err
	}

	blockCipher, err := aes.NewCipher(rawKey)
	if err != nil {
		return "", fmt.Errorf("failed to create AES cipher: %w", err)
	}

	aead, err := cipher.NewGCMWithRandomNonce(blockCipher)
	if err != nil {
		return "", fmt.Errorf("failed to create GCM mode: %w", err)
	}

	// With a random-nonce AEAD the nonce argument must be empty; Seal
	// generates it and prepends it to the output.
	sealed := aead.Seal(nil, nil, []byte(value), []byte(name))
	return base64.StdEncoding.EncodeToString(sealed), nil
}
// DecryptCookie reverses EncryptCookie: it base64-decodes value, then opens it
// with AES-GCM under the base64-encoded key, authenticating name as the
// additional data.
//
// It returns ErrInvalidEncryptedValue when the decoded value is too short to
// contain the nonce and tag, and a wrapped error for a bad key, malformed
// base64, or a failed authentication (for example a tampered value or a
// different cookie name).
func DecryptCookie(name, value, key string) (string, error) {
	rawKey, err := decodeKey(key)
	if err != nil {
		return "", err
	}

	sealed, err := base64.StdEncoding.DecodeString(value)
	if err != nil {
		return "", fmt.Errorf("failed to base64-decode value: %w", err)
	}

	blockCipher, err := aes.NewCipher(rawKey)
	if err != nil {
		return "", fmt.Errorf("failed to create AES cipher: %w", err)
	}

	aead, err := cipher.NewGCMWithRandomNonce(blockCipher)
	if err != nil {
		return "", fmt.Errorf("failed to create GCM mode: %w", err)
	}

	// Reject inputs that cannot even hold the nonce and tag before opening.
	if minLen := aead.NonceSize() + aead.Overhead(); len(sealed) < minLen {
		return "", ErrInvalidEncryptedValue
	}

	opened, err := aead.Open(nil, nil, sealed, []byte(name))
	if err != nil {
		return "", fmt.Errorf("failed to decrypt ciphertext: %w", err)
	}
	return string(opened), nil
}
// GenerateKey returns a new random AES key of length bytes, encoded with
// standard base64 so it can be used directly as Config.Key.
//
// length must be 16, 24, or 32, which selects AES-128, AES-192, or AES-256
// respectively. Any other length panics with ErrInvalidKeyLength. It also
// panics if the system random source returns an error.
func GenerateKey(length int) string {
	switch length {
	case 16, 24, 32:
		// supported AES key sizes
	default:
		panic(ErrInvalidKeyLength)
	}

	raw := make([]byte, length)
	if _, err := rand.Read(raw); err != nil {
		panic(err)
	}
	return base64.StdEncoding.EncodeToString(raw)
}
// isDisabled reports whether the cookie named key appears in except, meaning
// the middleware must neither decrypt nor encrypt it.
func isDisabled(key string, except []string) bool {
	return slices.Index(except, key) >= 0
}
================================================
FILE: middleware/envvar/config.go
================================================
package envvar
// Config defines the config for middleware.
type Config struct {
	// ExportVars specifies the environment variables that should be exported.
	// Each key is a variable name; its value is the default reported when the
	// variable is not set in the process environment. When empty, no
	// variables are exposed.
	//
	// Optional. Default: empty map
	ExportVars map[string]string
}
// ConfigDefault is the default config. Its ExportVars is empty, so with the
// defaults no environment variables are exposed.
var ConfigDefault = Config{
	ExportVars: make(map[string]string),
}
func configDefault(config ...Config) Config {
if len(config) == 0 {
return ConfigDefault
}
cfg := config[0]
if cfg.ExportVars == nil {
cfg.ExportVars = ConfigDefault.ExportVars
}
return cfg
}
================================================
FILE: middleware/envvar/envvar.go
================================================
package envvar
import (
"os"
"github.com/gofiber/fiber/v3"
)
// hAllow is the Allow header value sent with 405 responses; it lists the
// only methods the handler serves.
const hAllow = fiber.MethodGet + ", " + fiber.MethodHead
// EnvVar captures environment variables that are exposed through the
// middleware response.
type EnvVar struct {
	// Vars maps each exported variable name to its value; it is serialized
	// under the "vars" JSON key.
	Vars map[string]string `json:"vars"`
}

// set stores val under key in Vars. Vars must be non-nil (writing to a nil
// map panics); newEnvVar initializes it before calling set.
func (envVar *EnvVar) set(key, val string) {
	envVar.Vars[key] = val
}
// New creates a handler that returns the configured environment variables
// as a JSON object of the form {"vars": {...}}.
//
// Only GET and HEAD are served; any other method gets an Allow header and
// fiber.ErrMethodNotAllowed. Values are read from the environment on every
// request, so changes made after New is called are reflected.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)

	return func(c fiber.Ctx) error {
		switch c.Method() {
		case fiber.MethodGet, fiber.MethodHead:
			// served below
		default:
			c.Set(fiber.HeaderAllow, hAllow)
			return fiber.ErrMethodNotAllowed
		}

		payload, err := c.App().Config().JSONEncoder(newEnvVar(cfg))
		if err != nil {
			return c.Status(fiber.StatusInternalServerError).SendString(err.Error())
		}

		c.Set(fiber.HeaderContentType, fiber.MIMEApplicationJSONCharsetUTF8)
		return c.Send(payload)
	}
}
// newEnvVar builds the response payload for cfg. Every name in
// cfg.ExportVars is exported: with its environment value when the variable
// is set (even to ""), otherwise with the configured fallback. Variables
// not listed in cfg.ExportVars are never included.
func newEnvVar(cfg Config) *EnvVar {
	out := &EnvVar{Vars: make(map[string]string)}

	// With no allow-list nothing is exported, so a bare New() cannot leak
	// the process environment.
	if len(cfg.ExportVars) == 0 {
		return out
	}

	for name, fallback := range cfg.ExportVars {
		value := fallback
		if fromEnv, ok := os.LookupEnv(name); ok {
			value = fromEnv
		}
		out.set(name, value)
	}
	return out
}
================================================
FILE: middleware/envvar/envvar_test.go
================================================
package envvar
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
func Test_EnvVarStructWithExportVars(t *testing.T) {
t.Setenv("testKey", "testEnvValue")
t.Setenv("anotherEnvKey", "anotherEnvVal")
vars := newEnvVar(Config{
ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"},
})
require.Equal(t, "testEnvValue", vars.Vars["testKey"])
require.Equal(t, "testDefaultVal", vars.Vars["testDefaultKey"])
require.Empty(t, vars.Vars["anotherEnvKey"])
}
func Test_EnvVarHandler(t *testing.T) {
	t.Setenv("testKey", "testVal")

	want, err := json.Marshal(struct {
		Vars map[string]string `json:"vars"`
	}{Vars: map[string]string{"testKey": "testVal"}})
	require.NoError(t, err)

	app := fiber.New()
	app.Use("/envvars", New(Config{ExportVars: map[string]string{"testKey": ""}}))

	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", http.NoBody)
	require.NoError(t, err)
	resp, err := app.Test(req)
	require.NoError(t, err)

	got, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, want, got)
}
func Test_EnvVarHandlerNotMatched(t *testing.T) {
	app := fiber.New()
	app.Use("/envvars", New(Config{ExportVars: map[string]string{"testKey": ""}}))
	app.Get("/another-path", func(ctx fiber.Ctx) error {
		require.NoError(t, ctx.SendString("OK"))
		return nil
	})

	// Requests outside the mount path never reach the middleware.
	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/another-path", http.NoBody)
	require.NoError(t, err)
	resp, err := app.Test(req)
	require.NoError(t, err)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, []byte("OK"), body)
}
func Test_EnvVarHandlerDefaultConfig(t *testing.T) {
	t.Setenv("testEnvKey", "testEnvVal")

	app := fiber.New()
	app.Use("/envvars", New())

	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", http.NoBody)
	require.NoError(t, err)
	resp, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	// With no ExportVars nothing is exposed, even variables that are set.
	var decoded EnvVar
	require.NoError(t, json.Unmarshal(body, &decoded))
	require.NotContains(t, decoded.Vars, "testEnvKey")
}
func Test_EnvVarHandlerMethod(t *testing.T) {
	app := fiber.New()
	app.Use("/envvars", New())

	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "http://localhost/envvars", http.NoBody)
	require.NoError(t, err)
	resp, err := app.Test(req)
	require.NoError(t, err)

	// Anything but GET/HEAD is refused with 405 and told what is allowed.
	require.Equal(t, fiber.StatusMethodNotAllowed, resp.StatusCode)
	require.Equal(t, hAllow, resp.Header.Get(fiber.HeaderAllow))
}
func Test_EnvVarHandlerHead(t *testing.T) {
	app := fiber.New()
	app.Use("/envvars", New())

	resp, err := app.Test(httptest.NewRequest(fiber.MethodHead, "http://localhost/envvars", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	// HEAD is accepted but the response carries no body.
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Empty(t, string(body))
}
func Test_EnvVarHandlerSpecialValue(t *testing.T) {
	const testEnvKey = "testEnvKey"
	const fakeBase64 = "testBase64:TQ=="
	t.Setenv(testEnvKey, fakeBase64)

	app := fiber.New()
	app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}}))
	app.Use("/envvars", New())

	// fetch performs a GET against url and decodes the JSON payload.
	fetch := func(url string) EnvVar {
		t.Helper()
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, url, http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		var out EnvVar
		require.NoError(t, json.Unmarshal(body, &out))
		return out
	}

	// The default config exports nothing.
	_, exists := fetch("http://localhost/envvars").Vars[testEnvKey]
	require.False(t, exists)

	// The explicit allow-list exports the value verbatim, ':' and '=' included.
	require.Equal(t, fakeBase64, fetch("http://localhost/envvars/export").Vars[testEnvKey])
}
================================================
FILE: middleware/etag/config.go
================================================
package etag
import (
"github.com/gofiber/fiber/v3"
)
// Config defines the config for middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// Weak indicates that a weak validator is used. Weak etags are easy
	// to generate, but are far less useful for comparisons. Strong
	// validators are ideal for comparisons but can be very difficult
	// to generate efficiently. Weak ETag values of two representations
	// of the same resources might be semantically equivalent, but not
	// byte-for-byte identical. This means weak etags prevent caching
	// when byte range requests are used, but strong etags mean range
	// requests can still be cached.
	//
	// When true, generated tags carry the "W/" prefix (see GenerateWeak).
	//
	// Optional. Default: false
	Weak bool
}
// ConfigDefault is the default config: strong ETags, never skipped.
var ConfigDefault = Config{
	Next: nil,
	Weak: false,
}

// configDefault returns ConfigDefault when no config is supplied, otherwise
// the first supplied config.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}
	// The zero value of every field already equals its default, so the
	// first config is used exactly as given.
	return config[0]
}
================================================
FILE: middleware/etag/etag.go
================================================
package etag
import (
"bytes"
"hash/crc32"
"math"
"slices"
"github.com/gofiber/fiber/v3"
"github.com/valyala/bytebufferpool"
)
// weakPrefix is the "W/" marker that flags an entity tag as weak.
var weakPrefix = []byte("W/")

// crc32q is the lookup table for the CRC-32Q polynomial (reversed form
// 0xD5828281) used by Generate to checksum response bodies.
var crc32q = crc32.MakeTable(0xD5828281)
// Generate returns a strong ETag for body of the form "<len>-<crc>"
// (quotes included), where <len> is the body length in bytes and <crc> is
// the CRC-32Q checksum of body, both in decimal. It returns nil when body
// is longer than math.MaxUint32 bytes, since the length must fit a uint32.
func Generate(body []byte) []byte {
	if uint64(len(body)) > math.MaxUint32 {
		return nil
	}

	bb := bytebufferpool.Get()
	defer bytebufferpool.Put(bb)

	// Build directly in bb.B so any growth stays with the pooled buffer and
	// benefits later calls. Appending to a detached bb.B[:0] copy (as before)
	// left the pooled buffer empty, so every call allocated anyway.
	bb.B = append(bb.B[:0], '"')
	bb.B = appendUint(bb.B, uint32(len(body))) // #nosec G115 -- length checked above
	bb.B = append(bb.B, '-')
	bb.B = appendUint(bb.B, crc32.Checksum(body, crc32q))
	bb.B = append(bb.B, '"')

	// Copy out: bb goes back to the pool (and may be reused) on return.
	return slices.Clone(bb.B)
}
// GenerateWeak returns a weak ETag for body: the strong tag produced by
// Generate prefixed with "W/". It returns nil when Generate does.
func GenerateWeak(body []byte) []byte {
	tag := Generate(body)
	if tag == nil {
		return nil
	}
	// Cap the shared package-level prefix at its length so append always
	// allocates a fresh slice. A plain append(weakPrefix, ...) would write
	// into weakPrefix's backing array whenever it had spare capacity, racing
	// between concurrent requests.
	return append(weakPrefix[:len(weakPrefix):len(weakPrefix)], tag...)
}
// New creates a new middleware handler that adds an ETag header to 200 OK
// responses with a non-empty body that do not already carry one. When the
// request's If-None-Match header matches the generated tag, the body is
// dropped and 304 Not Modified is returned instead.
//
// Bodies longer than math.MaxUint32 bytes cannot be tagged and are answered
// with 413 Request Entity Too Large.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)

	normalizedHeaderETag := []byte("Etag")

	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Return err if next handler returns one
		if err := c.Next(); err != nil {
			return err
		}

		// Don't generate ETags for invalid responses
		if c.Response().StatusCode() != fiber.StatusOK {
			return nil
		}
		body := c.Response().Body()
		// Skips ETag if no response body is present
		if len(body) == 0 {
			return nil
		}
		// Skip ETag if header is already present
		if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil {
			return nil
		}

		if uint64(len(body)) > math.MaxUint32 {
			return c.SendStatus(fiber.StatusRequestEntityTooLarge)
		}

		var etag []byte
		if cfg.Weak {
			etag = GenerateWeak(body)
		} else {
			etag = Generate(body)
		}

		// Set the tag up front: a 304 must carry the same ETag a 200 would
		// have (RFC 9110 §15.4.5).
		c.Response().Header.SetCanonical(normalizedHeaderETag, etag)

		// If-None-Match uses the weak comparison (RFC 9110 §13.1.2): "W/" is
		// ignored on both sides. Searching the whole header for the server's
		// quoted opaque-tag matches it anywhere in a list such as
		// `W/"a", "x"`, with or without a W/ prefix on the client side.
		clientEtag := c.Request().Header.Peek(fiber.HeaderIfNoneMatch)
		if bytes.Contains(clientEtag, bytes.TrimPrefix(etag, weakPrefix)) {
			c.RequestCtx().ResetBody()
			return c.SendStatus(fiber.StatusNotModified)
		}
		return nil
	}
}
// appendUint appends the decimal representation of n to dst and returns
// the extended slice. It does not allocate beyond what append needs.
func appendUint(dst []byte, n uint32) []byte {
	// Ten digits cover math.MaxUint32; digits are written right to left.
	var digits [10]byte
	pos := len(digits)
	for {
		pos--
		digits[pos] = '0' + byte(n%10)
		n /= 10
		if n == 0 {
			break
		}
	}
	return append(dst, digits[pos:]...)
}
================================================
FILE: middleware/etag/etag_test.go
================================================
package etag
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// go test -run Test_ETag_Next
func Test_ETag_Next(t *testing.T) {
	t.Parallel()

	skipAll := func(_ fiber.Ctx) bool { return true }
	app := fiber.New()
	app.Use(New(Config{Next: skipAll}))

	// The middleware is skipped and no route exists, so the router answers 404.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_ETag_SkipError
func Test_ETag_SkipError(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(_ fiber.Ctx) error { return fiber.ErrForbidden })

	// Errors from downstream handlers are propagated untouched.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusForbidden, resp.StatusCode)
}
// go test -run Test_ETag_NotStatusOK
func Test_ETag_NotStatusOK(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusCreated) })

	// Only 200 responses are tagged; other statuses pass through as-is.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusCreated, resp.StatusCode)
}
// go test -run Test_ETag_NoBody
func Test_ETag_NoBody(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(_ fiber.Ctx) error { return nil })

	// An empty body is left untagged and the response stays 200.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
// go test -run Test_ETag_NewEtag
func Test_ETag_NewEtag(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name              string
		headerIfNoneMatch bool
		matched           bool
	}{
		{name: "without HeaderIfNoneMatch"},
		{name: "with HeaderIfNoneMatch and not matched", headerIfNoneMatch: true},
		{name: "with HeaderIfNoneMatch and matched", headerIfNoneMatch: true, matched: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			testETagNewEtag(t, tc.headerIfNoneMatch, tc.matched)
		})
	}
}
func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
	t.Helper()

	const want = `"13-1831710635"` // strong ETag of "Hello, World!"

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	if headerIfNoneMatch {
		clientTag := `"non-match"`
		if matched {
			clientTag = want
		}
		req.Header.Set(fiber.HeaderIfNoneMatch, clientTag)
	}

	resp, err := app.Test(req)
	require.NoError(t, err)

	if headerIfNoneMatch && matched {
		require.Equal(t, fiber.StatusNotModified, resp.StatusCode)
		body, readErr := io.ReadAll(resp.Body)
		require.NoError(t, readErr)
		require.Empty(t, body)
		return
	}
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, want, resp.Header.Get(fiber.HeaderETag))
}
// go test -run Test_ETag_WeakEtag
func Test_ETag_WeakEtag(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name              string
		headerIfNoneMatch bool
		matched           bool
	}{
		{name: "without HeaderIfNoneMatch"},
		{name: "with HeaderIfNoneMatch and not matched", headerIfNoneMatch: true},
		{name: "with HeaderIfNoneMatch and matched", headerIfNoneMatch: true, matched: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			testETagWeakEtag(t, tc.headerIfNoneMatch, tc.matched)
		})
	}
}
func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
	t.Helper()

	const want = `W/"13-1831710635"` // weak ETag of "Hello, World!"

	app := fiber.New()
	app.Use(New(Config{Weak: true}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	if headerIfNoneMatch {
		clientTag := `W/"non-match"`
		if matched {
			clientTag = want
		}
		req.Header.Set(fiber.HeaderIfNoneMatch, clientTag)
	}

	resp, err := app.Test(req)
	require.NoError(t, err)

	if headerIfNoneMatch && matched {
		require.Equal(t, fiber.StatusNotModified, resp.StatusCode)
		body, readErr := io.ReadAll(resp.Body)
		require.NoError(t, readErr)
		require.Empty(t, body)
		return
	}
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, want, resp.Header.Get(fiber.HeaderETag))
}
// go test -run Test_ETag_CustomEtag
func Test_ETag_CustomEtag(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name              string
		headerIfNoneMatch bool
		matched           bool
	}{
		{name: "without HeaderIfNoneMatch"},
		{name: "with HeaderIfNoneMatch and not matched", headerIfNoneMatch: true},
		{name: "with HeaderIfNoneMatch and matched", headerIfNoneMatch: true, matched: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			testETagCustomEtag(t, tc.headerIfNoneMatch, tc.matched)
		})
	}
}
func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine
	t.Helper()

	const custom = `"custom"`

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		// The handler owns validation; the middleware must leave its ETag alone.
		c.Set(fiber.HeaderETag, custom)
		if bytes.Equal(c.Request().Header.Peek(fiber.HeaderIfNoneMatch), []byte(custom)) {
			return c.SendStatus(fiber.StatusNotModified)
		}
		return c.SendString("Hello, World!")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	if headerIfNoneMatch {
		clientTag := `"non-match"`
		if matched {
			clientTag = custom
		}
		req.Header.Set(fiber.HeaderIfNoneMatch, clientTag)
	}

	resp, err := app.Test(req)
	require.NoError(t, err)

	if headerIfNoneMatch && matched {
		require.Equal(t, fiber.StatusNotModified, resp.StatusCode)
		body, readErr := io.ReadAll(resp.Body)
		require.NoError(t, readErr)
		require.Empty(t, body)
		return
	}
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, custom, resp.Header.Get(fiber.HeaderETag))
}
// go test -run Test_ETag_CustomEtagPut
func Test_ETag_CustomEtagPut(t *testing.T) {
	t.Parallel()

	const tag = `"custom"`

	app := fiber.New()
	app.Use(New())
	app.Put("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderETag, tag)
		if bytes.Equal(c.Request().Header.Peek(fiber.HeaderIfMatch), []byte(tag)) {
			return c.SendString("Hello, World!")
		}
		return c.SendStatus(fiber.StatusPreconditionFailed)
	})

	req := httptest.NewRequest(fiber.MethodPut, "/", http.NoBody)
	req.Header.Set(fiber.HeaderIfMatch, `"non-match"`)

	// The handler's 412 is not a 200, so the middleware leaves it as is.
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusPreconditionFailed, resp.StatusCode)
}
// go test -v -run=^$ -bench=Benchmark_Etag -benchmem -count=4
func Benchmark_Etag(b *testing.B) {
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})
	handler := app.Handler()

	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/")

	b.ReportAllocs()
	for b.Loop() {
		handler(ctx)
	}

	require.Equal(b, fiber.StatusOK, ctx.Response.Header.StatusCode())
	require.Equal(b, `"13-1831710635"`, string(ctx.Response.Header.Peek(fiber.HeaderETag)))
}
================================================
FILE: middleware/expvar/config.go
================================================
package expvar
import (
"github.com/gofiber/fiber/v3"
)
// Config defines the config for middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	// Skipped requests are passed straight to the next handler.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool
}
// ConfigDefault is the default config.
var ConfigDefault = Config{
	Next: nil,
}

// configDefault returns the first supplied config with unset fields taken
// from ConfigDefault, or ConfigDefault itself when none is supplied.
func configDefault(config ...Config) Config {
	if len(config) > 0 {
		cfg := config[0]
		if cfg.Next == nil {
			cfg.Next = ConfigDefault.Next
		}
		return cfg
	}
	return ConfigDefault
}
================================================
FILE: middleware/expvar/expvar.go
================================================
package expvar
import (
"strings"
"github.com/gofiber/fiber/v3"
"github.com/valyala/fasthttp/expvarhandler"
)
// New creates a new middleware handler that serves expvar data.
//
// Exactly "/debug/vars" is answered by fasthttp's expvar handler; any other
// path starting with "/debug/vars" is redirected there; all other requests
// are passed to the next handler.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)

	const varsPath = "/debug/vars"

	return func(c fiber.Ctx) error {
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		switch path := c.Path(); {
		case path == varsPath:
			expvarhandler.ExpvarHandler(c.RequestCtx())
			return nil
		case strings.HasPrefix(path, varsPath):
			return c.Redirect().To(varsPath)
		default:
			return c.Next()
		}
	}
}
================================================
FILE: middleware/expvar/expvar_test.go
================================================
package expvar
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
func Test_Non_Expvar_Path(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	// Paths outside /debug/vars pass straight through to the route.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "escaped", string(body))
}
func Test_Expvar_Index(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, fiber.MIMEApplicationJSONCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	// Without a filter every published variable is listed.
	for _, name := range []string{"cmdline", "memstat"} {
		require.True(t, bytes.Contains(body, []byte(name)), name)
	}
}
func Test_Expvar_Filter(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars?r=cmd", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, fiber.MIMEApplicationJSONCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	// The r query parameter narrows the listing: cmdline stays, memstats goes.
	require.True(t, bytes.Contains(body, []byte("cmdline")))
	require.False(t, bytes.Contains(body, []byte("memstat")))
}
func Test_Expvar_Other_Path(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	// Sub-paths of /debug/vars are redirected back to it.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars/303", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusSeeOther, resp.StatusCode)
}
// go test -run Test_Expvar_Next
func Test_Expvar_Next(t *testing.T) {
	t.Parallel()

	skipAll := func(_ fiber.Ctx) bool { return true }
	app := fiber.New()
	app.Use(New(Config{Next: skipAll}))

	// Skipped, and no route is registered, so even /debug/vars is a 404.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
================================================
FILE: middleware/favicon/config.go
================================================
package favicon
import (
"io/fs"
"github.com/gofiber/fiber/v3"
)
// Config defines the config for middleware.
type Config struct {
	// FileSystem is an optional alternate filesystem to search for the favicon in.
	// An example of this could be an embedded or network filesystem.
	// It is only consulted when File is set and Data is nil.
	//
	// Optional. Default: nil
	FileSystem fs.FS `json:"-"`

	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// File holds the path to an actual favicon that will be cached.
	// It is read once, when the middleware is created.
	//
	// Optional. Default: ""
	File string `json:"file"`

	// URL for favicon handler; it is compared exactly with the request path.
	//
	// Optional. Default: "/favicon.ico"
	URL string `json:"url"`

	// CacheControl defines how the Cache-Control header in the response should be set
	//
	// Optional. Default: "public, max-age=31536000"
	CacheControl string `json:"cache_control"`

	// Raw data of the favicon file. When non-nil it takes precedence over
	// File and FileSystem.
	//
	// Optional. Default: nil
	Data []byte `json:"-"`

	// MaxBytes limits the maximum size of the cached favicon asset when it is
	// read from File (Data is not checked). Values <= 0 use the default.
	//
	// Optional. Default: 1048576
	MaxBytes int64 `json:"max_bytes"`
}
// ConfigDefault is the default config: no icon, served at /favicon.ico,
// cached for one year, with a 1 MiB size limit.
var ConfigDefault = Config{
	Next:         nil,
	File:         "",
	URL:          fPath,
	CacheControl: "public, max-age=31536000",
	MaxBytes:     1024 * 1024,
}

// configDefault returns the first supplied config with every unset field
// (nil, empty, or a non-positive MaxBytes) taken from ConfigDefault, or
// ConfigDefault itself when no config is supplied.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}

	cfg := config[0]
	if cfg.MaxBytes <= 0 {
		cfg.MaxBytes = ConfigDefault.MaxBytes
	}
	if cfg.CacheControl == "" {
		cfg.CacheControl = ConfigDefault.CacheControl
	}
	if cfg.File == "" {
		cfg.File = ConfigDefault.File
	}
	if cfg.URL == "" {
		cfg.URL = ConfigDefault.URL
	}
	if cfg.Next == nil {
		cfg.Next = ConfigDefault.Next
	}
	return cfg
}
================================================
FILE: middleware/favicon/favicon.go
================================================
package favicon
import (
"fmt"
"io"
"io/fs"
"os"
"strconv"
"github.com/gofiber/fiber/v3"
)
const (
	// fPath is the default URL the favicon is served at.
	fPath = "/favicon.ico"
	// hType is the Content-Type sent with the icon.
	hType = "image/x-icon"
	// hAllow lists the methods accepted on the favicon URL (Allow header).
	hAllow = "GET, HEAD, OPTIONS"
	// hZero is the Content-Length sent with OPTIONS and 405 responses.
	hZero = "0"
)
// New creates a new middleware handler that serves a favicon at cfg.URL.
//
// The icon is resolved once, when New is called: cfg.Data is used as-is when
// non-nil; otherwise, when cfg.File is set, that file is read (through
// cfg.FileSystem if provided, else from the OS filesystem). New panics if
// the file cannot be opened or read, or if it is larger than cfg.MaxBytes.
//
// Requests for any other path are passed to the next handler. On cfg.URL,
// GET and HEAD receive the cached icon, or 204 No Content when there are no
// icon bytes. OPTIONS receives 200 and any other method 405, both with an
// Allow header and Content-Length: 0.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)

	var iconData []byte
	switch {
	case cfg.Data != nil:
		// use the provided favicon data
		iconData = cfg.Data
	case cfg.File != "":
		data, err := readIconFile(cfg.FileSystem, cfg.File, cfg.MaxBytes)
		if err != nil {
			panic(err)
		}
		iconData = data
	}
	iconLen := len(iconData)
	iconLenHeader := strconv.Itoa(iconLen)

	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Only respond to favicon requests
		if c.Path() != cfg.URL {
			return c.Next()
		}

		// Only allow GET, HEAD and OPTIONS requests
		if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead {
			if c.Method() != fiber.MethodOptions {
				c.Status(fiber.StatusMethodNotAllowed)
			} else {
				c.Status(fiber.StatusOK)
			}
			c.Set(fiber.HeaderAllow, hAllow)
			c.Set(fiber.HeaderContentLength, hZero)
			return nil
		}

		// Serve cached favicon
		if iconLen > 0 {
			c.Set(fiber.HeaderContentLength, iconLenHeader)
			c.Set(fiber.HeaderContentType, hType)
			c.Set(fiber.HeaderCacheControl, cfg.CacheControl)
			return c.Status(fiber.StatusOK).Send(iconData)
		}

		return c.SendStatus(fiber.StatusNoContent)
	}
}

// readIconFile opens name from fsys, or from the OS filesystem when fsys is
// nil, and reads it, failing when it is larger than maxBytes. Open and read
// errors are returned unchanged.
func readIconFile(fsys fs.FS, name string, maxBytes int64) ([]byte, error) {
	var (
		f   fs.File
		err error
	)
	if fsys != nil {
		f, err = fsys.Open(name)
	} else {
		f, err = os.Open(name)
	}
	if err != nil {
		return nil, err
	}
	defer func() {
		_ = f.Close() //nolint:errcheck // not needed
	}()
	return readLimited(f, maxBytes)
}
// readLimited reads all of reader, failing when it yields more than maxBytes
// bytes. At most maxBytes+1 bytes are consumed, so an oversized source is
// detected without being read in full.
func readLimited(reader io.Reader, maxBytes int64) ([]byte, error) {
	limit := maxBytes + 1
	if limit < maxBytes {
		// maxBytes == math.MaxInt64: the +1 overflowed to a negative limit,
		// which io.LimitReader treats as "read nothing" and which silently
		// produced an empty icon. No source can exceed MaxInt64 bytes, so
		// capping at maxBytes keeps the check exact.
		limit = maxBytes
	}
	data, err := io.ReadAll(io.LimitReader(reader, limit))
	if err != nil {
		return nil, fmt.Errorf("favicon: read limited: %w", err)
	}
	if int64(len(data)) > maxBytes {
		return nil, fmt.Errorf("favicon: file size exceeds max bytes %d", maxBytes)
	}
	return data, nil
}
================================================
FILE: middleware/favicon/favicon_test.go
================================================
package favicon
import (
"bytes"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v3"
)
// go test -run Test_Middleware_Favicon
func Test_Middleware_Favicon(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(_ fiber.Ctx) error {
		return nil
	})

	do := func(method, target string) *http.Response {
		t.Helper()
		resp, err := app.Test(httptest.NewRequest(method, target, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		return resp
	}

	// Other paths skip the favicon middleware.
	require.Equal(t, fiber.StatusOK, do(fiber.MethodGet, "/").StatusCode, "Status code")
	// No icon configured: GET answers 204 No Content.
	require.Equal(t, fiber.StatusNoContent, do(fiber.MethodGet, "/favicon.ico").StatusCode, "Status code")
	// OPTIONS is accepted.
	require.Equal(t, fiber.StatusOK, do(fiber.MethodOptions, "/favicon.ico").StatusCode, "Status code")

	// Any other method is rejected and told which methods are allowed.
	resp := do(fiber.MethodPut, "/favicon.ico")
	require.Equal(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code")
	require.Equal(t, "GET, HEAD, OPTIONS", resp.Header.Get(fiber.HeaderAllow))
}
// go test -run Test_Middleware_Favicon_Not_Found
func Test_Middleware_Favicon_Not_Found(t *testing.T) {
t.Parallel()
defer func() {
if err := recover(); err == nil {
t.Error("should catch panic")
return
}
}()
fiber.New().Use(New(Config{
File: "non-exist.ico",
}))
}
// go test -run Test_Middleware_Favicon_MaxBytes
func Test_Middleware_Favicon_MaxBytes(t *testing.T) {
	t.Parallel()

	// An 11-byte icon against a 10-byte limit.
	dir := t.TempDir()
	path := dir + "/favicon.ico"
	require.NoError(t, os.WriteFile(path, bytes.Repeat([]byte("a"), 11), 0o600))

	// Pin the panic to the size check: previously any panic (for example a
	// "file not found" caused by a wrong path) made this test pass.
	require.PanicsWithError(t, "favicon: file size exceeds max bytes 10", func() {
		fiber.New().Use(New(Config{
			File:     path,
			MaxBytes: 10,
		}))
	})
}
// go test -run Test_Middleware_Favicon_MaxBytes_FileSystem
func Test_Middleware_Favicon_MaxBytes_FileSystem(t *testing.T) {
	t.Parallel()

	// An 11-byte icon, read through an fs.FS, against a 10-byte limit.
	dir := t.TempDir()
	require.NoError(t, os.WriteFile(dir+"/favicon.ico", bytes.Repeat([]byte("a"), 11), 0o600))

	// Pin the panic to the size check so that, e.g., a "file does not exist"
	// panic from a wrong name cannot make this test pass.
	require.PanicsWithError(t, "favicon: file size exceeds max bytes 10", func() {
		fiber.New().Use(New(Config{
			File:       "favicon.ico",
			FileSystem: os.DirFS(dir),
			MaxBytes:   10,
		}))
	})
}
// go test -run Test_Middleware_Favicon_Found
func Test_Middleware_Favicon_Found(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{File: "../../.github/testdata/favicon.ico"}))
	app.Get("/", func(_ fiber.Ctx) error {
		return nil
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", http.NoBody))
	require.NoError(t, err, "app.Test(req)")

	// The icon is served with its content type and the default Cache-Control.
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
	require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
	require.Equal(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// go test -run Test_Custom_Favicon_URL
func Test_Custom_Favicon_URL(t *testing.T) {
	const customURL = "/favicon.svg"

	app := fiber.New()
	app.Use(New(Config{
		File: "../../.github/testdata/favicon.ico",
		URL:  customURL,
	}))
	app.Get("/", func(_ fiber.Ctx) error {
		return nil
	})

	// The icon is served on the configured URL instead of /favicon.ico.
	resp, err := app.Test(httptest.NewRequest(http.MethodGet, customURL, http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
	require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
}
// go test -run Test_Custom_Favicon_Data
func Test_Custom_Favicon_Data(t *testing.T) {
	icon, err := os.ReadFile("../../.github/testdata/favicon.ico")
	require.NoError(t, err)

	app := fiber.New()
	app.Use(New(Config{Data: icon}))
	app.Get("/", func(_ fiber.Ctx) error {
		return nil
	})

	// In-memory Data is served exactly like a file-backed icon.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
	require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
	require.Equal(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// go test -run Test_Middleware_Favicon_FileSystem
func Test_Middleware_Favicon_FileSystem(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		File:       "favicon.ico",
		FileSystem: os.DirFS("../../.github/testdata"),
	}))

	// File is resolved relative to the supplied fs.FS.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
	require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
	require.Equal(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// go test -run Test_Middleware_Favicon_CacheControl
func Test_Middleware_Favicon_CacheControl(t *testing.T) {
	t.Parallel()

	const cacheControl = "public, max-age=100"

	app := fiber.New()
	app.Use(New(Config{
		CacheControl: cacheControl,
		File:         "../../.github/testdata/favicon.ico",
	}))

	// A custom CacheControl replaces the default header value.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code")
	require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType))
	require.Equal(t, cacheControl, resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// go test -v -run=^$ -bench=Benchmark_Middleware_Favicon -benchmem -count=4
func Benchmark_Middleware_Favicon(b *testing.B) {
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(_ fiber.Ctx) error {
		return nil
	})
	h := app.Handler()

	// The request targets "/", not the favicon URL, so this measures the
	// middleware's pass-through cost.
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.SetRequestURI("/")

	b.ReportAllocs()
	for b.Loop() {
		h(ctx)
	}
}
// go test -run Test_Favicon_Next
func Test_Favicon_Next(t *testing.T) {
	t.Parallel()

	skipAll := func(_ fiber.Ctx) bool { return true }
	app := fiber.New()
	app.Use(New(Config{Next: skipAll}))

	// Skipped middleware and no route registered: the router answers 404.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
================================================
FILE: middleware/healthcheck/config.go
================================================
package healthcheck
import (
"github.com/gofiber/fiber/v3"
)
// Config defines the configuration options for the healthcheck middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true. If this function returns true
	// and no other handlers are defined for the route, Fiber will return a status 404 Not Found, since
	// no other handlers were defined to return a different status.
	//
	// Optional. Default: nil
	Next func(fiber.Ctx) bool

	// Probe is executed to determine the current health state. It can be used for liveness,
	// readiness or startup checks. Returning true indicates the application is healthy.
	// It is only called for GET requests: true yields 200 OK, false yields
	// 503 Service Unavailable.
	//
	// Optional. Default: func(c fiber.Ctx) bool { return true }
	Probe func(fiber.Ctx) bool
}
// Conventional probe paths. New does not register any route by itself; mount
// the handler on these (or any other) paths, e.g. app.Get(LivenessEndpoint, New()).
const (
// LivenessEndpoint is the conventional path for a liveness check.
// Register the middleware on this path to expose it.
LivenessEndpoint = "/livez"
// ReadinessEndpoint is the conventional path for a readiness check.
// Register the middleware on this path to expose it.
ReadinessEndpoint = "/readyz"
// StartupEndpoint is the conventional path for a startup check.
// Register the middleware on this path to expose it.
StartupEndpoint = "/startupz"
)
// defaultProbe is the fallback Probe: it ignores the request and always
// reports the application as healthy.
func defaultProbe(fiber.Ctx) bool {
	return true
}
// ConfigDefault is the default configuration: no Next filter and a Probe
// that always reports healthy.
var ConfigDefault = Config{
	Probe: defaultProbe,
	Next:  nil,
}
// configDefault returns ConfigDefault when no Config is supplied. Otherwise it
// returns the first supplied Config (extra values are ignored), with a nil
// Probe replaced by ConfigDefault.Probe.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}

	cfg := config[0]
	if cfg.Probe != nil {
		return cfg
	}
	cfg.Probe = ConfigDefault.Probe
	return cfg
}
================================================
FILE: middleware/healthcheck/healthcheck.go
================================================
package healthcheck
import (
"github.com/gofiber/fiber/v3"
)
// New returns a health-check handler built from the optional configuration.
//
// For GET requests it runs cfg.Probe and replies 200 OK when the probe returns
// true, 503 Service Unavailable otherwise. Requests with any other method, and
// requests for which cfg.Next returns true, are handed to the next handler.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)

	return func(c fiber.Ctx) error {
		skip := cfg.Next != nil && cfg.Next(c)
		if skip || c.Method() != fiber.MethodGet {
			return c.Next()
		}

		status := fiber.StatusServiceUnavailable
		if cfg.Probe(c) {
			status = fiber.StatusOK
		}
		return c.SendStatus(status)
	}
}
================================================
FILE: middleware/healthcheck/healthcheck_test.go
================================================
package healthcheck
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// shouldGiveStatus issues a GET request for path against app and asserts that
// the response carries expectedStatus.
func shouldGiveStatus(t *testing.T, app *fiber.App, path string, expectedStatus int) {
	t.Helper()

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
	require.NoError(t, err)

	msg := "path: " + path + " should match " + strconv.Itoa(expectedStatus)
	require.Equal(t, expectedStatus, resp.StatusCode, msg)
}
// shouldGiveOK asserts that a GET request for path is answered with 200 OK.
func shouldGiveOK(t *testing.T, app *fiber.App, path string) {
	t.Helper()
	want := fiber.StatusOK
	shouldGiveStatus(t, app, path, want)
}
// shouldGiveNotFound asserts that a GET request for path is answered with
// 404 Not Found.
func shouldGiveNotFound(t *testing.T, app *fiber.App, path string) {
	t.Helper()
	want := fiber.StatusNotFound
	shouldGiveStatus(t, app, path, want)
}
// Test_HealthCheck_Strict_Routing_Default checks that, with StrictRouting
// enabled, the default endpoints answer only on their exact paths.
func Test_HealthCheck_Strict_Routing_Default(t *testing.T) {
	t.Parallel()

	app := fiber.New(fiber.Config{StrictRouting: true})
	for _, endpoint := range []string{LivenessEndpoint, ReadinessEndpoint, StartupEndpoint} {
		app.Get(endpoint, New())
	}

	for _, path := range []string{"/readyz", "/livez", "/startupz"} {
		shouldGiveOK(t, app, path)
	}
	for _, path := range []string{
		"/readyz/", "/livez/", "/startupz/",
		"/notDefined/readyz", "/notDefined/livez", "/notDefined/startupz",
	} {
		shouldGiveNotFound(t, app, path)
	}
}
// Test_HealthCheck_Default checks that, with fiber.New() defaults, the default
// endpoints also answer with a trailing slash, while unregistered nested paths
// still return 404.
func Test_HealthCheck_Default(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	for _, endpoint := range []string{LivenessEndpoint, ReadinessEndpoint, StartupEndpoint} {
		app.Get(endpoint, New())
	}

	for _, path := range []string{
		"/readyz", "/livez", "/startupz",
		"/readyz/", "/livez/", "/startupz/",
	} {
		shouldGiveOK(t, app, path)
	}
	for _, path := range []string{"/notDefined/readyz", "/notDefined/livez", "/notDefined/startupz"} {
		shouldGiveNotFound(t, app, path)
	}
}
// Test_HealthCheck_Custom exercises user-supplied probes on custom paths: a
// liveness probe that always succeeds, a readiness probe that succeeds only
// when a signal is waiting on a channel, and a startup probe that always fails.
func Test_HealthCheck_Custom(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	c1 := make(chan struct{}, 1)

	// Setup custom liveness, readiness and startup probes to simulate the
	// application health status.
	app.Get("/live", New(Config{
		Probe: func(_ fiber.Ctx) bool {
			return true
		},
	}))
	// Readiness succeeds only when a value is available on c1; each successful
	// probe consumes that value.
	app.Get("/ready", New(Config{
		Probe: func(_ fiber.Ctx) bool {
			select {
			case <-c1:
				return true
			default:
				return false
			}
		},
	}))
	app.Get(StartupEndpoint, New(Config{
		Probe: func(_ fiber.Ctx) bool {
			return false
		},
	}))

	// Live should return 200 with GET request
	shouldGiveOK(t, app, "/live")

	// Live should return 405 with POST request: the route is registered via
	// app.Get only, so the router rejects the method.
	req, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/live", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusMethodNotAllowed, req.StatusCode)

	// Ready should return 405 with POST request for the same reason.
	req, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/ready", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusMethodNotAllowed, req.StatusCode)

	// Ready should return 503 with GET request while no signal is pending,
	// and the startup probe always reports unhealthy.
	shouldGiveStatus(t, app, "/ready", fiber.StatusServiceUnavailable)
	shouldGiveStatus(t, app, StartupEndpoint, fiber.StatusServiceUnavailable)

	// Ready should return 200 with GET request once a signal has been sent on
	// the channel (the channel itself is never closed).
	c1 <- struct{}{}
	shouldGiveOK(t, app, "/ready")
}
// Test_HealthCheck_Custom_Nested checks probes mounted on nested custom paths:
// they answer with and without a trailing slash, and the default endpoint
// names are not served anywhere.
func Test_HealthCheck_Custom_Nested(t *testing.T) {
	t.Parallel()

	ready := make(chan struct{}, 1)
	alwaysUp := func(_ fiber.Ctx) bool { return true }
	readyOnSignal := func(_ fiber.Ctx) bool {
		select {
		case <-ready:
			return true
		default:
			return false
		}
	}

	app := fiber.New()
	app.Get("/probe/live", New(Config{Probe: alwaysUp}))
	app.Get("/probe/ready", New(Config{Probe: readyOnSignal}))

	// Testing custom health check endpoints with nested paths
	shouldGiveOK(t, app, "/probe/live")
	shouldGiveStatus(t, app, "/probe/ready", fiber.StatusServiceUnavailable)
	shouldGiveOK(t, app, "/probe/live/")
	shouldGiveStatus(t, app, "/probe/ready/", fiber.StatusServiceUnavailable)

	for _, path := range []string{
		"/probe/livez", "/probe/readyz", "/probe/livez/", "/probe/readyz/",
		"/livez", "/readyz", "/readyz/", "/livez/",
	} {
		shouldGiveNotFound(t, app, path)
	}

	// Each successful readiness probe consumes one signal, so send one per check.
	ready <- struct{}{}
	shouldGiveOK(t, app, "/probe/ready")
	ready <- struct{}{}
	shouldGiveOK(t, app, "/probe/ready/")
}
// Test_HealthCheck_Next checks that a Next function returning true bypasses
// the checker on every endpoint it is mounted on.
func Test_HealthCheck_Next(t *testing.T) {
	t.Parallel()

	alwaysSkip := func(_ fiber.Ctx) bool { return true }
	checker := New(Config{Next: alwaysSkip})

	app := fiber.New()
	for _, endpoint := range []string{LivenessEndpoint, ReadinessEndpoint, StartupEndpoint} {
		app.Get(endpoint, checker)
	}

	// The checker always defers to the next handler and none exists, so each
	// endpoint behaves as if the route were not defined at all.
	for _, path := range []string{"/readyz", "/livez", "/startupz"} {
		shouldGiveNotFound(t, app, path)
	}
}
// Benchmark_HealthCheck measures single-goroutine GETs of the liveness
// endpoint, then checks the last response status is 200 OK.
func Benchmark_HealthCheck(b *testing.B) {
	app := fiber.New()
	for _, endpoint := range []string{LivenessEndpoint, ReadinessEndpoint, StartupEndpoint} {
		app.Get(endpoint, New())
	}
	handler := app.Handler()

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod(fiber.MethodGet)
	reqCtx.Request.SetRequestURI("/livez")

	b.ReportAllocs()
	for b.Loop() {
		handler(reqCtx)
	}

	require.Equal(b, fiber.StatusOK, reqCtx.Response.Header.StatusCode())
}
// Benchmark_HealthCheck_Parallel measures concurrent GETs of the liveness
// endpoint; each benchmark goroutine reuses its own fasthttp.RequestCtx.
func Benchmark_HealthCheck_Parallel(b *testing.B) {
	app := fiber.New()
	for _, endpoint := range []string{LivenessEndpoint, ReadinessEndpoint, StartupEndpoint} {
		app.Get(endpoint, New())
	}
	handler := app.Handler()

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		reqCtx := &fasthttp.RequestCtx{}
		reqCtx.Request.Header.SetMethod(fiber.MethodGet)
		reqCtx.Request.SetRequestURI("/livez")
		for pb.Next() {
			handler(reqCtx)
		}
	})
}
================================================
FILE: middleware/helmet/config.go
================================================
package helmet
import (
"github.com/gofiber/fiber/v3"
)
// Config defines the config for middleware.
//
// Each string field is the value of one response header written by New. When a
// Config is passed to New, every string field that has a non-empty value in
// ConfigDefault is refilled from it if left empty, so leaving such a field empty
// does not suppress its header.
type Config struct {
// Next defines a function to skip middleware.
// Optional. Default: nil
Next func(fiber.Ctx) bool
// XSSProtection is the value of the X-XSS-Protection header.
// Optional. Default value "0".
XSSProtection string
// ContentTypeNosniff is the value of the X-Content-Type-Options header.
// Optional. Default value "nosniff".
ContentTypeNosniff string
// XFrameOptions is the value of the X-Frame-Options header.
// Optional. Default value "SAMEORIGIN".
// Possible values: "SAMEORIGIN", "DENY", "ALLOW-FROM uri"
XFrameOptions string
// ContentSecurityPolicy is the policy sent in Content-Security-Policy (or in
// Content-Security-Policy-Report-Only when CSPReportOnly is true).
// When empty, neither header is sent.
// Optional. Default value "".
ContentSecurityPolicy string
// ReferrerPolicy is the value of the Referrer-Policy header.
// Optional. Default value "no-referrer".
ReferrerPolicy string
// PermissionPolicy is the value of the Permissions-Policy header.
// When empty, the header is not sent.
// Optional. Default value "".
PermissionPolicy string
// CrossOriginEmbedderPolicy is the value of the Cross-Origin-Embedder-Policy header.
// Optional. Default value "require-corp".
CrossOriginEmbedderPolicy string
// CrossOriginOpenerPolicy is the value of the Cross-Origin-Opener-Policy header.
// Optional. Default value "same-origin".
CrossOriginOpenerPolicy string
// CrossOriginResourcePolicy is the value of the Cross-Origin-Resource-Policy header.
// Optional. Default value "same-origin".
CrossOriginResourcePolicy string
// OriginAgentCluster is the value of the Origin-Agent-Cluster header.
// Optional. Default value "?1".
OriginAgentCluster string
// XDNSPrefetchControl is the value of the X-DNS-Prefetch-Control header.
// Optional. Default value "off".
XDNSPrefetchControl string
// XDownloadOptions is the value of the X-Download-Options header.
// Optional. Default value "noopen".
XDownloadOptions string
// XPermittedCrossDomain is the value of the X-Permitted-Cross-Domain-Policies header.
// Optional. Default value "none".
XPermittedCrossDomain string
// HSTSMaxAge is the max-age directive of Strict-Transport-Security (per the
// HSTS spec, in seconds). The header is only sent when this is non-zero and
// the request protocol is "https".
// Optional. Default value 0.
HSTSMaxAge int
// HSTSExcludeSubdomains omits the "; includeSubDomains" directive from
// Strict-Transport-Security when true.
// Optional. Default value false.
HSTSExcludeSubdomains bool
// CSPReportOnly sends ContentSecurityPolicy in Content-Security-Policy-Report-Only
// instead of Content-Security-Policy when true.
// Optional. Default value false.
CSPReportOnly bool
// HSTSPreloadEnabled appends the "; preload" directive to Strict-Transport-Security
// when true.
// Optional. Default value false.
HSTSPreloadEnabled bool
}
// ConfigDefault is the default config. ContentSecurityPolicy and
// PermissionPolicy are left empty, so those headers are not sent by default,
// and HSTS is off because HSTSMaxAge is 0.
var ConfigDefault = Config{
	// Framing, sniffing and referrer headers.
	XSSProtection:      "0",
	ContentTypeNosniff: "nosniff",
	XFrameOptions:      "SAMEORIGIN",
	ReferrerPolicy:     "no-referrer",

	// Cross-origin headers.
	CrossOriginEmbedderPolicy: "require-corp",
	CrossOriginOpenerPolicy:   "same-origin",
	CrossOriginResourcePolicy: "same-origin",
	OriginAgentCluster:        "?1",

	// Miscellaneous headers.
	XDNSPrefetchControl:   "off",
	XDownloadOptions:      "noopen",
	XPermittedCrossDomain: "none",
}
// configDefault builds the effective Config. With no argument it returns
// ConfigDefault. Otherwise it takes the first Config (extra values are
// ignored) and replaces every empty header field that has a non-empty
// default with the value from ConfigDefault. Next, ContentSecurityPolicy,
// PermissionPolicy and the HSTS/CSP options are used exactly as supplied.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}

	cfg := config[0]

	// fill assigns fallback to *field when the caller left it empty.
	fill := func(field *string, fallback string) {
		if *field == "" {
			*field = fallback
		}
	}

	fill(&cfg.XSSProtection, ConfigDefault.XSSProtection)
	fill(&cfg.ContentTypeNosniff, ConfigDefault.ContentTypeNosniff)
	fill(&cfg.XFrameOptions, ConfigDefault.XFrameOptions)
	fill(&cfg.ReferrerPolicy, ConfigDefault.ReferrerPolicy)
	fill(&cfg.CrossOriginEmbedderPolicy, ConfigDefault.CrossOriginEmbedderPolicy)
	fill(&cfg.CrossOriginOpenerPolicy, ConfigDefault.CrossOriginOpenerPolicy)
	fill(&cfg.CrossOriginResourcePolicy, ConfigDefault.CrossOriginResourcePolicy)
	fill(&cfg.OriginAgentCluster, ConfigDefault.OriginAgentCluster)
	fill(&cfg.XDNSPrefetchControl, ConfigDefault.XDNSPrefetchControl)
	fill(&cfg.XDownloadOptions, ConfigDefault.XDownloadOptions)
	fill(&cfg.XPermittedCrossDomain, ConfigDefault.XPermittedCrossDomain)

	return cfg
}
================================================
FILE: middleware/helmet/helmet.go
================================================
package helmet
import (
"fmt"
"github.com/gofiber/fiber/v3"
)
// New creates a middleware handler that writes the security-related response
// headers configured in Config and then calls the next handler.
//
// Each non-empty string header field is set on every request. The CSP value is
// written to Content-Security-Policy-Report-Only instead of
// Content-Security-Policy when CSPReportOnly is true. Strict-Transport-Security
// is only written when HSTSMaxAge is non-zero and c.Protocol() equals "https".
func New(config ...Config) fiber.Handler {
	// Init config
	cfg := configDefault(config...)

	// The Strict-Transport-Security value depends only on the configuration,
	// so format it once here instead of calling fmt.Sprintf on every request.
	// An empty hstsValue means HSTS is disabled (HSTSMaxAge == 0).
	hstsValue := ""
	if cfg.HSTSMaxAge != 0 {
		directives := ""
		if !cfg.HSTSExcludeSubdomains {
			directives = "; includeSubDomains"
		}
		if cfg.HSTSPreloadEnabled {
			directives += "; preload"
		}
		hstsValue = fmt.Sprintf("max-age=%d%s", cfg.HSTSMaxAge, directives)
	}

	// Likewise, pick the Content-Security-Policy header name once.
	cspHeader := fiber.HeaderContentSecurityPolicy
	if cfg.CSPReportOnly {
		cspHeader = fiber.HeaderContentSecurityPolicyReportOnly
	}

	// Return middleware handler
	return func(c fiber.Ctx) error {
		// Next request to skip middleware
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Set headers
		if cfg.XSSProtection != "" {
			c.Set(fiber.HeaderXXSSProtection, cfg.XSSProtection)
		}
		if cfg.ContentTypeNosniff != "" {
			c.Set(fiber.HeaderXContentTypeOptions, cfg.ContentTypeNosniff)
		}
		if cfg.XFrameOptions != "" {
			c.Set(fiber.HeaderXFrameOptions, cfg.XFrameOptions)
		}
		if cfg.CrossOriginEmbedderPolicy != "" {
			c.Set("Cross-Origin-Embedder-Policy", cfg.CrossOriginEmbedderPolicy)
		}
		if cfg.CrossOriginOpenerPolicy != "" {
			c.Set("Cross-Origin-Opener-Policy", cfg.CrossOriginOpenerPolicy)
		}
		if cfg.CrossOriginResourcePolicy != "" {
			c.Set("Cross-Origin-Resource-Policy", cfg.CrossOriginResourcePolicy)
		}
		if cfg.OriginAgentCluster != "" {
			c.Set("Origin-Agent-Cluster", cfg.OriginAgentCluster)
		}
		if cfg.ReferrerPolicy != "" {
			c.Set("Referrer-Policy", cfg.ReferrerPolicy)
		}
		if cfg.XDNSPrefetchControl != "" {
			c.Set("X-DNS-Prefetch-Control", cfg.XDNSPrefetchControl)
		}
		if cfg.XDownloadOptions != "" {
			c.Set("X-Download-Options", cfg.XDownloadOptions)
		}
		if cfg.XPermittedCrossDomain != "" {
			c.Set("X-Permitted-Cross-Domain-Policies", cfg.XPermittedCrossDomain)
		}

		// Handle HSTS headers.
		// NOTE(review): this compares c.Protocol() with "https" (the tests drive it
		// through Request.Header.SetProtocol("https")). Confirm that Protocol()
		// reports the request scheme rather than the HTTP version (e.g. "HTTP/1.1")
		// in this Fiber version; otherwise HSTS is never sent for real TLS traffic.
		if hstsValue != "" && c.Protocol() == "https" {
			c.Set(fiber.HeaderStrictTransportSecurity, hstsValue)
		}

		// Handle Content-Security-Policy headers
		if cfg.ContentSecurityPolicy != "" {
			c.Set(cspHeader, cfg.ContentSecurityPolicy)
		}

		// Handle Permissions-Policy headers
		if cfg.PermissionPolicy != "" {
			c.Set(fiber.HeaderPermissionsPolicy, cfg.PermissionPolicy)
		}

		return c.Next()
	}
}
================================================
FILE: middleware/helmet/helmet_test.go
================================================
package helmet
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// Test_Default checks the headers emitted with the default configuration:
// every defaulted header carries its ConfigDefault value, and the optional
// Content-Security-Policy and Permissions-Policy headers are absent.
func Test_Default(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection))
	require.Equal(t, "nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions))
	require.Equal(t, "SAMEORIGIN", resp.Header.Get(fiber.HeaderXFrameOptions))
	require.Empty(t, resp.Header.Get(fiber.HeaderContentSecurityPolicy))
	require.Equal(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy))
	require.Empty(t, resp.Header.Get(fiber.HeaderPermissionsPolicy))
	require.Equal(t, "require-corp", resp.Header.Get("Cross-Origin-Embedder-Policy"))
	require.Equal(t, "same-origin", resp.Header.Get("Cross-Origin-Opener-Policy"))
	require.Equal(t, "same-origin", resp.Header.Get("Cross-Origin-Resource-Policy"))
	require.Equal(t, "?1", resp.Header.Get("Origin-Agent-Cluster"))
	require.Equal(t, "off", resp.Header.Get("X-DNS-Prefetch-Control"))
	require.Equal(t, "noopen", resp.Header.Get("X-Download-Options"))
	require.Equal(t, "none", resp.Header.Get("X-Permitted-Cross-Domain-Policies"))
}
// Test_CustomValues_AllHeaders checks that custom values for every header
// field are written verbatim, with the CSP sent in report-only mode.
func Test_CustomValues_AllHeaders(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		// Custom values for all headers
		XSSProtection:             "0",
		ContentTypeNosniff:        "custom-nosniff",
		XFrameOptions:             "DENY",
		HSTSExcludeSubdomains:     true,
		ContentSecurityPolicy:     "default-src 'none'",
		CSPReportOnly:             true,
		HSTSPreloadEnabled:        true,
		ReferrerPolicy:            "origin",
		PermissionPolicy:          "geolocation=(self)",
		CrossOriginEmbedderPolicy: "custom-value",
		CrossOriginOpenerPolicy:   "custom-value",
		CrossOriginResourcePolicy: "custom-value",
		OriginAgentCluster:        "custom-value",
		XDNSPrefetchControl:       "custom-control",
		XDownloadOptions:          "custom-options",
		XPermittedCrossDomain:     "custom-policies",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)

	// Assertions for custom header values
	require.Equal(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection))
	require.Equal(t, "custom-nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions))
	require.Equal(t, "DENY", resp.Header.Get(fiber.HeaderXFrameOptions))
	require.Equal(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicyReportOnly))
	require.Equal(t, "origin", resp.Header.Get(fiber.HeaderReferrerPolicy))
	require.Equal(t, "geolocation=(self)", resp.Header.Get(fiber.HeaderPermissionsPolicy))
	require.Equal(t, "custom-value", resp.Header.Get("Cross-Origin-Embedder-Policy"))
	require.Equal(t, "custom-value", resp.Header.Get("Cross-Origin-Opener-Policy"))
	require.Equal(t, "custom-value", resp.Header.Get("Cross-Origin-Resource-Policy"))
	require.Equal(t, "custom-value", resp.Header.Get("Origin-Agent-Cluster"))
	require.Equal(t, "custom-control", resp.Header.Get("X-DNS-Prefetch-Control"))
	require.Equal(t, "custom-options", resp.Header.Get("X-Download-Options"))
	require.Equal(t, "custom-policies", resp.Header.Get("X-Permitted-Cross-Domain-Policies"))
}
// Test_RealWorldValues_AllHeaders checks a realistic configuration, including
// a multi-directive enforced (non report-only) Content-Security-Policy.
func Test_RealWorldValues_AllHeaders(t *testing.T) {
	t.Parallel()

	// Shared between the config and the assertion so the two cannot drift apart.
	const csp = "default-src 'self';base-uri 'self';font-src 'self' https: data:;form-action 'self';frame-ancestors 'self';img-src 'self' data:;object-src 'none';script-src 'self';script-src-attr 'none';style-src 'self' https: 'unsafe-inline';upgrade-insecure-requests"

	app := fiber.New()
	app.Use(New(Config{
		// Real-world values for all headers
		XSSProtection:             "0",
		ContentTypeNosniff:        "nosniff",
		XFrameOptions:             "SAMEORIGIN",
		HSTSExcludeSubdomains:     false,
		ContentSecurityPolicy:     csp,
		CSPReportOnly:             false,
		HSTSPreloadEnabled:        true,
		ReferrerPolicy:            "no-referrer",
		PermissionPolicy:          "geolocation=(self)",
		CrossOriginEmbedderPolicy: "require-corp",
		CrossOriginOpenerPolicy:   "same-origin",
		CrossOriginResourcePolicy: "same-origin",
		OriginAgentCluster:        "?1",
		XDNSPrefetchControl:       "off",
		XDownloadOptions:          "noopen",
		XPermittedCrossDomain:     "none",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)

	// Assertions for real-world header values
	require.Equal(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection))
	require.Equal(t, "nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions))
	require.Equal(t, "SAMEORIGIN", resp.Header.Get(fiber.HeaderXFrameOptions))
	require.Equal(t, csp, resp.Header.Get(fiber.HeaderContentSecurityPolicy))
	require.Equal(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy))
	require.Equal(t, "geolocation=(self)", resp.Header.Get(fiber.HeaderPermissionsPolicy))
	require.Equal(t, "require-corp", resp.Header.Get("Cross-Origin-Embedder-Policy"))
	require.Equal(t, "same-origin", resp.Header.Get("Cross-Origin-Opener-Policy"))
	require.Equal(t, "same-origin", resp.Header.Get("Cross-Origin-Resource-Policy"))
	require.Equal(t, "?1", resp.Header.Get("Origin-Agent-Cluster"))
	require.Equal(t, "off", resp.Header.Get("X-DNS-Prefetch-Control"))
	require.Equal(t, "noopen", resp.Header.Get("X-Download-Options"))
	require.Equal(t, "none", resp.Header.Get("X-Permitted-Cross-Domain-Policies"))
}
// Test_Next checks that headers are written on regular routes but skipped for
// requests where Next returns true (here: the /next path).
func Test_Next(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Next: func(ctx fiber.Ctx) bool {
			return ctx.Path() == "/next"
		},
		ReferrerPolicy: "no-referrer",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})
	app.Get("/next", func(c fiber.Ctx) error {
		return c.SendString("Skipped!")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/next", http.NoBody))
	require.NoError(t, err)
	require.Empty(t, resp.Header.Get(fiber.HeaderReferrerPolicy))
}
// Test_ContentSecurityPolicy checks that a configured policy is sent in the
// enforcing Content-Security-Policy header by default.
func Test_ContentSecurityPolicy(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		ContentSecurityPolicy: "default-src 'none'",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicy))
}
// Test_ContentSecurityPolicyReportOnly checks that CSPReportOnly moves the
// policy to Content-Security-Policy-Report-Only and leaves the enforcing
// header unset.
func Test_ContentSecurityPolicyReportOnly(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		ContentSecurityPolicy: "default-src 'none'",
		CSPReportOnly:         true,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicyReportOnly))
	require.Empty(t, resp.Header.Get(fiber.HeaderContentSecurityPolicy))
}
// Test_PermissionsPolicy checks that PermissionPolicy is sent in the
// Permissions-Policy header.
func Test_PermissionsPolicy(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		PermissionPolicy: "microphone=()",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, "microphone=()", resp.Header.Get(fiber.HeaderPermissionsPolicy))
}
// Test_HSTSHeaders checks that Strict-Transport-Security, including the
// default includeSubDomains directive, is sent when the request protocol is
// "https" and omitted when it is "http".
func Test_HSTSHeaders(t *testing.T) {
	t.Parallel()

	hstsAge := 60
	app := fiber.New()
	app.Use(New(Config{HSTSMaxAge: hstsAge}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})
	handler := app.Handler()

	ctx := &fasthttp.RequestCtx{}
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetProtocol("https")
	handler(ctx)
	require.Equal(t, "max-age=60; includeSubDomains", string(ctx.Response.Header.Peek(fiber.HeaderStrictTransportSecurity)))

	// Reuse the context for a plain-HTTP request; no HSTS header is expected.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetProtocol("http")
	handler(ctx)
	require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderStrictTransportSecurity)))
}
// Test_HSTSExcludeSubdomainsAndPreload checks that HSTSExcludeSubdomains drops
// includeSubDomains while HSTSPreloadEnabled appends the preload directive.
func Test_HSTSExcludeSubdomainsAndPreload(t *testing.T) {
	t.Parallel()

	hstsAge := 31536000
	app := fiber.New()
	app.Use(New(Config{
		HSTSMaxAge:            hstsAge,
		HSTSExcludeSubdomains: true,
		HSTSPreloadEnabled:    true,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})
	handler := app.Handler()

	ctx := &fasthttp.RequestCtx{}
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetProtocol("https")
	handler(ctx)
	require.Equal(t, "max-age=31536000; preload", string(ctx.Response.Header.Peek(fiber.HeaderStrictTransportSecurity)))
}
================================================
FILE: middleware/idempotency/config.go
================================================
package idempotency
import (
"errors"
"fmt"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
)
// ErrInvalidIdempotencyKey is wrapped (via %w) by the default KeyHeaderValidate
// when the idempotency header value has the wrong length; match it with errors.Is.
var ErrInvalidIdempotencyKey = errors.New("invalid idempotency key")
// Config defines the config for middleware.
type Config struct {
// Lock locks an idempotency key while its first request is being processed,
// so concurrent requests with the same key wait and then replay the stored response.
//
// Optional. Default: an in-memory locker for this process only.
Lock Locker
// Storage stores response data by idempotency key.
// Entries are written with Lifetime as their expiration.
//
// Optional. Default: an in-memory storage for this process only.
Storage fiber.Storage
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: a function which skips the middleware on safe HTTP request method.
Next func(c fiber.Ctx) bool
// KeyHeaderValidate defines a function to validate the syntax of the idempotency header.
// A non-nil error is returned from the handler unchanged.
//
// Optional. Default: a function which ensures the header is 36 characters long (the size of an UUID).
KeyHeaderValidate func(string) error
// KeyHeader is the name of the header that contains the idempotency key.
// Requests without this header (or with an empty value) are not handled by the middleware.
//
// Optional. Default: X-Idempotency-Key
KeyHeader string
// KeepResponseHeaders is a list of headers that should be kept from the original response.
// Names are matched case-insensitively. An empty, non-nil slice is treated like nil.
//
// Optional. Default: nil (to keep all headers)
KeepResponseHeaders []string
// Lifetime is the maximum lifetime of an idempotency key.
// The default in-memory storage also uses half of it as its GC interval.
//
// Optional. Default: 30 * time.Minute
Lifetime time.Duration
// DisableValueRedaction turns off masking idempotency keys in log output
// (keys are otherwise replaced by "[redacted]") when set to true.
//
// Optional. Default: false
DisableValueRedaction bool
}
// ConfigDefault is the default config. Lock and Storage are intentionally nil
// here; configDefault creates fresh in-memory instances, so nothing is
// allocated at package initialization.
var ConfigDefault = Config{
	Lock:    nil, // Set in configDefault so we don't allocate data here.
	Storage: nil, // Set in configDefault so we don't allocate data here.

	// Skip middleware if the request was done using a safe HTTP method.
	Next: func(c fiber.Ctx) bool {
		return fiber.IsMethodSafe(c.Method())
	},

	// Accept only keys exactly as long as a textual UUID.
	KeyHeaderValidate: func(k string) error {
		const uuidLen = 36
		got := len(k)
		if got == uuidLen {
			return nil
		}
		return fmt.Errorf("%w: invalid length: %d != %d", ErrInvalidIdempotencyKey, got, uuidLen)
	},

	KeyHeader:             "X-Idempotency-Key",
	KeepResponseHeaders:   nil,
	Lifetime:              30 * time.Minute,
	DisableValueRedaction: false,
}
// configDefault returns the effective middleware configuration.
//
// With no argument it returns a copy of ConfigDefault with a fresh in-memory
// Lock and Storage. Otherwise it takes the first Config (extra values are
// ignored) and fills every unset field from ConfigDefault. A nil Lock or
// Storage gets a fresh in-memory implementation; the storage's GCInterval is
// half of the effective Lifetime.
func configDefault(config ...Config) Config {
	// Return default config if nothing provided
	if len(config) < 1 {
		cfg := ConfigDefault
		cfg.Lock = NewMemoryLock()
		cfg.Storage = memory.New(memory.Config{
			GCInterval: cfg.Lifetime / 2, // Half the lifetime interval
		})
		return cfg
	}

	// Override default config
	cfg := config[0]

	// Set default values
	if cfg.Next == nil {
		cfg.Next = ConfigDefault.Next
	}
	// Zero and negative lifetimes are both treated as "unset": a negative
	// duration would otherwise be handed to Storage as an entry expiration and
	// to the default storage as its GC interval.
	if cfg.Lifetime <= 0 {
		cfg.Lifetime = ConfigDefault.Lifetime
	}
	if cfg.KeyHeader == "" {
		cfg.KeyHeader = ConfigDefault.KeyHeader
	}
	if cfg.KeyHeaderValidate == nil {
		cfg.KeyHeaderValidate = ConfigDefault.KeyHeaderValidate
	}
	// An empty but non-nil slice would otherwise mean "keep no headers";
	// normalize it to the default (nil = keep all).
	if cfg.KeepResponseHeaders != nil && len(cfg.KeepResponseHeaders) == 0 {
		cfg.KeepResponseHeaders = ConfigDefault.KeepResponseHeaders
	}
	if cfg.Lock == nil {
		cfg.Lock = NewMemoryLock()
	}
	if cfg.Storage == nil {
		cfg.Storage = memory.New(memory.Config{
			GCInterval: cfg.Lifetime / 2,
		})
	}

	return cfg
}
================================================
FILE: middleware/idempotency/idempotency.go
================================================
package idempotency
import (
"fmt"
"github.com/gofiber/utils/v2"
utilsstrings "github.com/gofiber/utils/v2/strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
)
// Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02
// and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
// Keys under which the middleware records per-request state in c.Locals.
// Read them through IsFromCache and WasPutToCache.
const (
// localsKeyIsFromCache is set to true when the response was replayed from Storage.
localsKeyIsFromCache contextKey = iota
// localsKeyWasPutToCache is set to true after the response is stored, or to false
// when storing is skipped because the body exceeds the app's BodyLimit.
localsKeyWasPutToCache
)
// redactedKey replaces the idempotency key in log messages unless
// Config.DisableValueRedaction is set.
const redactedKey = "[redacted]"
// IsFromCache reports whether the middleware served the current request's
// response from the cache. The middleware marks this with a value under an
// unexported Locals key; any non-nil value counts as "served from cache".
func IsFromCache(c fiber.Ctx) bool {
	marker := c.Locals(localsKeyIsFromCache)
	return marker != nil
}
// WasPutToCache reports whether the middleware stored the response produced by
// the current request in the cache. A stored bool is returned as-is (the
// middleware records false when the body was too large to cache); any other
// non-nil value counts as true, and an absent value as false.
func WasPutToCache(c fiber.Ctx) bool {
	switch marker := c.Locals(localsKeyWasPutToCache).(type) {
	case nil:
		return false
	case bool:
		return marker
	default:
		return true
	}
}
// New creates idempotency middleware that caches responses keyed by the
// configured idempotency header.
//
// Requests skipped by cfg.Next (by default: safe HTTP methods) or lacking the
// key header go straight to the next handler. Otherwise the key is checked
// with cfg.KeyHeaderValidate, and a validation error is returned as-is. If a
// response is already stored under the key, it is replayed. If not, the key
// is locked via cfg.Lock and Storage is checked again under the lock. Then the
// downstream handlers run, and their response (status, headers filtered by
// cfg.KeepResponseHeaders, body) is stored for cfg.Lifetime. A handler error is
// returned without caching, and responses whose body exceeds the app's
// BodyLimit are not stored.
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)

	redactKeys := !cfg.DisableValueRedaction
	maskKey := func(key string) string {
		if redactKeys {
			return redactedKey
		}
		return key
	}

	// Lower-cased names of the response headers to keep. Response header names
	// are lower-cased the same way at lookup, so matching is case-insensitive.
	keepResponseHeadersMap := make(map[string]struct{}, len(cfg.KeepResponseHeaders))
	for _, h := range cfg.KeepResponseHeaders {
		// CopyString guarantees the map key owns its bytes regardless of how
		// ToLower builds its result (it may use UnsafeString internally), since
		// map keys must stay immutable for the lifetime of the handler.
		keepResponseHeadersMap[utils.CopyString(utilsstrings.ToLower(h))] = struct{}{}
	}

	// maybeWriteCachedResponse replays the response stored under key, if any.
	// It reports true whenever a stored response was found, even if sending its
	// body failed, so callers never re-run the handler for a cached key.
	maybeWriteCachedResponse := func(c fiber.Ctx, key string) (bool, error) {
		if val, err := cfg.Storage.GetWithContext(c, key); err != nil {
			return false, fmt.Errorf("failed to read response: %w", err)
		} else if val != nil {
			var res response
			if _, err := res.UnmarshalMsg(val); err != nil {
				return false, fmt.Errorf("failed to unmarshal response: %w", err)
			}

			_ = c.Status(res.StatusCode)

			for header, vals := range res.Headers {
				for _, val := range vals {
					c.RequestCtx().Response.Header.Add(header, val)
				}
			}

			if len(res.Body) != 0 {
				if err := c.Send(res.Body); err != nil {
					return true, err
				}
			}

			_ = c.Locals(localsKeyIsFromCache, true)

			return true, nil
		}

		return false, nil
	}

	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Read the idempotency header once; an empty value means the client did
		// not request idempotent handling.
		rawKey := c.Get(cfg.KeyHeader)
		if rawKey == "" {
			return c.Next()
		}

		// Copy the key: it is held by the lock and used as a storage key beyond
		// this point, independent of the request's header buffer.
		key := utils.CopyString(rawKey)

		// Validate key
		if err := cfg.KeyHeaderValidate(key); err != nil {
			return err
		}

		// First-pass: if the idempotency key is in the storage, get and return the response
		if ok, err := maybeWriteCachedResponse(c, key); err != nil {
			return fmt.Errorf("failed to write cached response at fastpath: %w", err)
		} else if ok {
			return nil
		}

		if err := cfg.Lock.Lock(key); err != nil {
			return fmt.Errorf("failed to lock: %w", err)
		}
		defer func() {
			if err := cfg.Lock.Unlock(key); err != nil {
				log.Errorf("[IDEMPOTENCY] failed to unlock key %q: %v", maskKey(key), err)
			}
		}()

		// Lock acquired. If the idempotency key now is in the storage, get and return the response
		if ok, err := maybeWriteCachedResponse(c, key); err != nil {
			return fmt.Errorf("failed to write cached response while locked: %w", err)
		} else if ok {
			return nil
		}

		// Execute the request handler
		if err := c.Next(); err != nil {
			// If the request handler returned an error, return it and skip idempotency
			return err
		}

		// Construct response
		res := &response{
			StatusCode: c.Response().StatusCode(),

			Body: c.Response().Body(),
		}
		{
			headers := make(map[string][]string)
			if err := c.Bind().RespHeader(headers); err != nil {
				return fmt.Errorf("failed to bind to response headers: %w", err)
			}

			if cfg.KeepResponseHeaders == nil {
				// Keep all
				res.Headers = headers
			} else {
				// Filter
				res.Headers = make(map[string][]string)
				for h := range headers {
					if _, ok := keepResponseHeadersMap[utilsstrings.ToLower(h)]; ok {
						res.Headers[h] = headers[h]
					}
				}
			}
		}

		// Responses larger than the app's BodyLimit are served but not cached.
		bodyLimit := c.App().Config().BodyLimit
		if bodyLimit > 0 && len(res.Body) > bodyLimit {
			_ = c.Locals(localsKeyWasPutToCache, false)
			return nil
		}

		// Marshal response
		bs, err := res.MarshalMsg(nil)
		if err != nil {
			return fmt.Errorf("failed to marshal response: %w", err)
		}

		// Store response
		if err := cfg.Storage.SetWithContext(c, key, bs, cfg.Lifetime); err != nil {
			return fmt.Errorf("failed to save response: %w", err)
		}

		_ = c.Locals(localsKeyWasPutToCache, true)

		return nil
	}
}
================================================
FILE: middleware/idempotency/idempotency_test.go
================================================
package idempotency
import (
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/valyala/fasthttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// validKey is an idempotency key accepted by the default KeyHeaderValidate:
// it is 36 characters long.
const validKey = "00000000-0000-0000-0000-000000000000"
// go test -run Test_Idempotency
func Test_Idempotency(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(func(c fiber.Ctx) error {
if err := c.Next(); err != nil {
return err
}
isMethodSafe := fiber.IsMethodSafe(c.Method())
isIdempotent := IsFromCache(c) || WasPutToCache(c)
hasReqHeader := c.Get("X-Idempotency-Key") != ""
if isMethodSafe {
if isIdempotent {
return errors.New("request with safe HTTP method should not be idempotent")
}
} else {
// Unsafe
if hasReqHeader {
if !isIdempotent {
return errors.New("request with unsafe HTTP method should be idempotent if X-Idempotency-Key request header is set")
}
} else if isIdempotent {
return errors.New("request with unsafe HTTP method should not be idempotent if X-Idempotency-Key request header is not set")
}
}
return nil
})
// Needs to be at least a second as the memory storage doesn't support shorter durations.
const lifetime = 2 * time.Second
app.Use(New(Config{
Lifetime: lifetime,
}))
nextCount := func() func() int {
var count int32
return func() int {
return int(atomic.AddInt32(&count, 1))
}
}()
app.Add([]string{
fiber.MethodGet,
fiber.MethodPost,
}, "/", func(c fiber.Ctx) error {
return c.SendString(strconv.Itoa(nextCount()))
})
app.Post("/slow", func(c fiber.Ctx) error {
time.Sleep(3 * lifetime)
return c.SendString(strconv.Itoa(nextCount()))
})
doReq := func(method, route, idempotencyKey string) string {
req := httptest.NewRequest(method, route, http.NoBody)
if idempotencyKey != "" {
req.Header.Set("X-Idempotency-Key", idempotencyKey)
}
resp, err := app.Test(req, fiber.TestConfig{
Timeout: 15 * time.Second,
FailOnTimeout: true,
})
require.NoError(t, err)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode, string(body))
return string(body)
}
require.Equal(t, "1", doReq(fiber.MethodGet, "/", ""))
require.Equal(t, "2", doReq(fiber.MethodGet, "/", ""))
require.Equal(t, "3", doReq(fiber.MethodPost, "/", ""))
require.Equal(t, "4", doReq(fiber.MethodPost, "/", ""))
require.Equal(t, "5", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))
require.Equal(t, "6", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))
require.Equal(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
require.Equal(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
require.Equal(t, "8", doReq(fiber.MethodPost, "/", ""))
require.Equal(t, "9", doReq(fiber.MethodPost, "/", "11111111-1111-1111-1111-111111111111"))
require.Equal(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
time.Sleep(4 * lifetime)
require.Equal(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
require.Equal(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
// Test raciness
{
var wg sync.WaitGroup
for range 100 {
wg.Go(func() {
assert.Equal(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
})
}
wg.Wait()
require.Equal(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
}
time.Sleep(3 * lifetime)
require.Equal(t, "12", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
}
// go test -v -run=^$ -bench=Benchmark_Idempotency -benchmem -count=4
func Benchmark_Idempotency(b *testing.B) {
app := fiber.New()
// Needs to be at least a second as the memory storage doesn't support shorter durations.
const lifetime = 1 * time.Second
app.Use(New(Config{
Lifetime: lifetime,
}))
app.Post("/", func(_ fiber.Ctx) error {
return nil
})
h := app.Handler()
b.Run("hit", func(b *testing.B) {
c := &fasthttp.RequestCtx{}
c.Request.Header.SetMethod(fiber.MethodPost)
c.Request.SetRequestURI("/")
c.Request.Header.Set("X-Idempotency-Key", "00000000-0000-0000-0000-000000000000")
b.ReportAllocs()
for b.Loop() {
h(c)
}
})
b.Run("skip", func(b *testing.B) {
c := &fasthttp.RequestCtx{}
c.Request.Header.SetMethod(fiber.MethodPost)
c.Request.SetRequestURI("/")
b.ReportAllocs()
for b.Loop() {
h(c)
}
})
}
func Test_configDefault_defaults(t *testing.T) {
t.Parallel()
cfg := configDefault()
require.NotNil(t, cfg.Lock)
require.NotNil(t, cfg.Storage)
require.Equal(t, ConfigDefault.Lifetime, cfg.Lifetime)
require.Equal(t, ConfigDefault.KeyHeader, cfg.KeyHeader)
require.Nil(t, cfg.KeepResponseHeaders)
app := fiber.New()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
ctx := app.AcquireCtx(fctx)
require.True(t, cfg.Next(ctx))
app.ReleaseCtx(ctx)
fctx = &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodPost)
ctx = app.AcquireCtx(fctx)
require.False(t, cfg.Next(ctx))
app.ReleaseCtx(ctx)
require.NoError(t, cfg.KeyHeaderValidate(validKey))
require.Error(t, cfg.KeyHeaderValidate("short"))
}
func Test_configDefault_override(t *testing.T) {
t.Parallel()
l := &stubLock{}
s := &stubStorage{}
cfg := configDefault(Config{
Lifetime: 42 * time.Second,
KeyHeader: "Foo",
KeepResponseHeaders: []string{},
Lock: l,
Storage: s,
})
require.Equal(t, 42*time.Second, cfg.Lifetime)
require.Equal(t, "Foo", cfg.KeyHeader)
require.Nil(t, cfg.KeepResponseHeaders)
require.Equal(t, l, cfg.Lock)
require.Equal(t, s, cfg.Storage)
require.NotNil(t, cfg.Next)
require.NotNil(t, cfg.KeyHeaderValidate)
}
// helper to perform request
func do(app *fiber.App, req *http.Request) (resp *http.Response, body string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned response and body payload for clarity
resp, err := app.Test(req, fiber.TestConfig{Timeout: 5 * time.Second})
if err != nil {
panic(err)
}
payload, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
return resp, string(payload)
}
func Test_New_NextSkip(t *testing.T) {
t.Parallel()
app := fiber.New()
var count int
app.Use(New(Config{Next: func(_ fiber.Ctx) bool { return true }}))
app.Post("/", func(c fiber.Ctx) error {
count++
return c.SendString(strconv.Itoa(count))
})
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
_, body1 := do(app, req)
req2 := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req2.Header.Set(ConfigDefault.KeyHeader, validKey)
_, body2 := do(app, req2)
require.Equal(t, "1", body1)
require.Equal(t, "2", body2)
}
func Test_New_InvalidKey(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Post("/", func(_ fiber.Ctx) error { return nil })
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, "bad")
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "invalid length")
}
func Test_New_StorageGetError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{getErr: errors.New("boom")}
app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to write cached response at fastpath")
}
func Test_New_UnmarshalError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{data: map[string][]byte{validKey: []byte("bad")}}
app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to write cached response at fastpath")
}
func Test_New_StoreRetrieve_FilterHeaders(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{}
app.Use(New(Config{
Storage: s,
Lock: &stubLock{},
KeepResponseHeaders: []string{"Foo"},
}))
var count int
app.Post("/", func(c fiber.Ctx) error {
count++
c.Set("Foo", "foo")
c.Set("Bar", "bar")
return c.SendString(fmt.Sprintf("resp%d", count))
})
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, "resp1", body)
require.Equal(t, "foo", resp.Header.Get("Foo"))
require.Equal(t, "bar", resp.Header.Get("Bar"))
req2 := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req2.Header.Set(ConfigDefault.KeyHeader, validKey)
resp2, body2 := do(app, req2)
require.Equal(t, "resp1", body2)
require.Equal(t, "foo", resp2.Header.Get("Foo"))
require.Empty(t, resp2.Header.Get("Bar"))
require.Equal(t, 1, count)
require.Equal(t, 1, s.setCount)
}
func Test_New_SkipCache_WhenBodyTooLarge(t *testing.T) {
t.Parallel()
bodyLimit := 8
app := fiber.New(fiber.Config{BodyLimit: bodyLimit})
s := &stubStorage{}
var wasPut []bool
app.Use(func(c fiber.Ctx) error {
if err := c.Next(); err != nil {
return err
}
wasPut = append(wasPut, WasPutToCache(c))
return nil
})
app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
var count int
oversized := strings.Repeat("a", bodyLimit+1)
app.Post("/", func(c fiber.Ctx) error {
count++
return c.SendString(oversized)
})
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp1, body1 := do(app, req)
require.Equal(t, fiber.StatusOK, resp1.StatusCode)
require.Equal(t, oversized, body1)
req2 := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req2.Header.Set(ConfigDefault.KeyHeader, validKey)
resp2, body2 := do(app, req2)
require.Equal(t, fiber.StatusOK, resp2.StatusCode)
require.Equal(t, oversized, body2)
require.Equal(t, 2, count)
require.Equal(t, 0, s.setCount)
require.Len(t, wasPut, 2)
require.False(t, wasPut[0])
require.False(t, wasPut[1])
}
func Test_New_HandlerError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{}
app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
app.Post("/", func(_ fiber.Ctx) error { return errors.New("boom") })
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Equal(t, "boom", body)
require.Equal(t, 0, s.setCount)
resp2, body2 := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp2.StatusCode)
require.Equal(t, "boom", body2)
require.Equal(t, 0, s.setCount)
}
func Test_New_LockError(t *testing.T) {
t.Parallel()
app := fiber.New()
l := &stubLock{lockErr: errors.New("fail")}
app.Use(New(Config{Lock: l, Storage: &stubStorage{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to lock")
}
func Test_New_StorageSetError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{setErr: errors.New("nope")}
app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to save response")
}
func Test_New_UnlockError(t *testing.T) {
t.Parallel()
app := fiber.New()
l := &stubLock{unlockErr: errors.New("u")}
app.Use(New(Config{Lock: l, Storage: &stubStorage{}}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "ok", body)
}
func Test_New_SecondPassReadError(t *testing.T) {
t.Parallel()
app := fiber.New()
s := &stubStorage{}
l := &stubLock{afterLock: func() { s.getErr = errors.New("g") }}
app.Use(New(Config{Lock: l, Storage: s}))
app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })
req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
req.Header.Set(ConfigDefault.KeyHeader, validKey)
resp, body := do(app, req)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Contains(t, body, "failed to write cached response while locked")
}
================================================
FILE: middleware/idempotency/locker.go
================================================
package idempotency
import (
"sync"
)
// Locker implements a spinlock for a string key.
type Locker interface {
Lock(key string) error
Unlock(key string) error
}
type countedLock struct {
mu sync.Mutex
locked int
}
// MemoryLock coordinates access to idempotency keys using in-memory locks.
type MemoryLock struct {
keys map[string]*countedLock
mu sync.Mutex
}
// Lock acquires the lock for the provided key, creating it when necessary.
func (l *MemoryLock) Lock(key string) error {
l.mu.Lock()
lock, ok := l.keys[key]
if !ok {
lock = new(countedLock)
l.keys[key] = lock
}
lock.locked++
l.mu.Unlock()
lock.mu.Lock()
return nil
}
// Unlock releases the lock associated with the provided key.
func (l *MemoryLock) Unlock(key string) error {
l.mu.Lock()
lock, ok := l.keys[key]
if !ok {
// This happens if we try to unlock an unknown key
l.mu.Unlock()
return nil
}
l.mu.Unlock()
lock.mu.Unlock()
l.mu.Lock()
lock.locked--
if lock.locked <= 0 {
// This happens if countedLock is used to Lock and Unlock the same number of times
// So, we can delete the key to prevent memory leak
delete(l.keys, key)
}
l.mu.Unlock()
return nil
}
// NewMemoryLock creates a MemoryLock ready for use.
func NewMemoryLock() *MemoryLock {
return &MemoryLock{
keys: make(map[string]*countedLock),
}
}
var _ Locker = (*MemoryLock)(nil)
================================================
FILE: middleware/idempotency/locker_test.go
================================================
package idempotency_test
import (
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/gofiber/fiber/v3/middleware/idempotency"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// go test -run Test_MemoryLock
func Test_MemoryLock(t *testing.T) {
t.Parallel()
l := idempotency.NewMemoryLock()
// Test that a lock can be acquired
{
err := l.Lock("a")
require.NoError(t, err)
}
// Test that the same lock cannot be acquired again while held
{
done := make(chan struct{})
go func() {
defer close(done)
err := l.Lock("a")
assert.NoError(t, err)
}()
select {
case <-done:
t.Fatal("lock acquired again")
case <-time.After(time.Second):
// Expected: goroutine should still be blocked
}
}
// Release lock "a" to prevent goroutine leak
{
err := l.Unlock("a")
require.NoError(t, err)
}
// Test lock and unlock sequence
{
err := l.Lock("b")
require.NoError(t, err)
}
{
err := l.Unlock("b")
require.NoError(t, err)
}
{
err := l.Lock("b")
require.NoError(t, err)
}
{
err := l.Unlock("b")
require.NoError(t, err)
}
// Test unlocking non-existent lock (should succeed)
{
err := l.Unlock("c")
require.NoError(t, err)
}
// Test another lock
{
err := l.Lock("d")
require.NoError(t, err)
}
{
err := l.Unlock("d")
require.NoError(t, err)
}
}
func Benchmark_MemoryLock(b *testing.B) {
keys := make([]string, 50_000_000)
for i := range keys {
keys[i] = strconv.Itoa(i)
}
lock := idempotency.NewMemoryLock()
for i := 0; b.Loop(); i++ {
key := keys[i]
if err := lock.Lock(key); err != nil {
b.Fatal(err)
}
if err := lock.Unlock(key); err != nil {
b.Fatal(err)
}
}
}
func Benchmark_MemoryLock_Parallel(b *testing.B) {
// In order to prevent using repeated keys I pre-allocate keys
keys := make([]string, 1_000_000)
for i := range keys {
keys[i] = strconv.Itoa(i)
}
b.Run("UniqueKeys", func(b *testing.B) {
lock := idempotency.NewMemoryLock()
var keyI atomic.Int32
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(p *testing.PB) {
for p.Next() {
i := int(keyI.Add(1)) % len(keys)
key := keys[i]
if err := lock.Lock(key); err != nil {
b.Fatal(err)
}
if err := lock.Unlock(key); err != nil {
b.Fatal(err)
}
}
})
})
b.Run("RepeatedKeys", func(b *testing.B) {
lock := idempotency.NewMemoryLock()
var keyI atomic.Int32
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(p *testing.PB) {
for p.Next() {
// Division by 3 ensures that index will be repeated exactly 3 times
i := int(keyI.Add(1)) / 3 % len(keys)
key := keys[i]
if err := lock.Lock(key); err != nil {
b.Fatal(err)
}
if err := lock.Unlock(key); err != nil {
b.Fatal(err)
}
}
})
})
}
================================================
FILE: middleware/idempotency/response.go
================================================
package idempotency
// response is a struct that represents the response of a request.
// generation tool `go install github.com/tinylib/msgp@latest`
//
// Idempotency payloads are stored in backing storage, so keep headers/bodies bounded.
//
//go:generate msgp -o=response_msgp.go -tests=true -unexported
type response struct {
Headers map[string][]string `msg:"hs,limit=1024"` // HTTP header count norms are well below this.
Body []byte `msg:"b"` // Idempotency bodies are bounded by storage policy, not msgp limits.
StatusCode int `msg:"sc"`
}
================================================
FILE: middleware/idempotency/response_msgp.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package idempotency
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *response) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "hs":
var zb0002 uint32
zb0002, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err, "Headers")
return
}
if zb0002 > 1024 {
err = msgp.ErrLimitExceeded
return
}
if z.Headers == nil {
z.Headers = make(map[string][]string, zb0002)
} else if len(z.Headers) > 0 {
clear(z.Headers)
}
for zb0002 > 0 {
zb0002--
var za0001 string
za0001, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "Headers")
return
}
var za0002 []string
var zb0003 uint32
zb0003, err = dc.ReadArrayHeader()
if err != nil {
err = msgp.WrapError(err, "Headers", za0001)
return
}
if zb0003 > 1024 {
err = msgp.ErrLimitExceeded
return
}
if cap(za0002) >= int(zb0003) {
za0002 = (za0002)[:zb0003]
} else {
za0002 = make([]string, zb0003)
}
for za0003 := range za0002 {
za0002[za0003], err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "Headers", za0001, za0003)
return
}
}
z.Headers[za0001] = za0002
}
case "b":
z.Body, err = dc.ReadBytes(z.Body)
if err != nil {
err = msgp.WrapError(err, "Body")
return
}
case "sc":
z.StatusCode, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "StatusCode")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z *response) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 3
// write "hs"
err = en.Append(0x83, 0xa2, 0x68, 0x73)
if err != nil {
return
}
err = en.WriteMapHeader(uint32(len(z.Headers)))
if err != nil {
err = msgp.WrapError(err, "Headers")
return
}
for za0001, za0002 := range z.Headers {
err = en.WriteString(za0001)
if err != nil {
err = msgp.WrapError(err, "Headers")
return
}
err = en.WriteArrayHeader(uint32(len(za0002)))
if err != nil {
err = msgp.WrapError(err, "Headers", za0001)
return
}
for za0003 := range za0002 {
err = en.WriteString(za0002[za0003])
if err != nil {
err = msgp.WrapError(err, "Headers", za0001, za0003)
return
}
}
}
// write "b"
err = en.Append(0xa1, 0x62)
if err != nil {
return
}
err = en.WriteBytes(z.Body)
if err != nil {
err = msgp.WrapError(err, "Body")
return
}
// write "sc"
err = en.Append(0xa2, 0x73, 0x63)
if err != nil {
return
}
err = en.WriteInt(z.StatusCode)
if err != nil {
err = msgp.WrapError(err, "StatusCode")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z *response) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 3
// string "hs"
o = append(o, 0x83, 0xa2, 0x68, 0x73)
o = msgp.AppendMapHeader(o, uint32(len(z.Headers)))
for za0001, za0002 := range z.Headers {
o = msgp.AppendString(o, za0001)
o = msgp.AppendArrayHeader(o, uint32(len(za0002)))
for za0003 := range za0002 {
o = msgp.AppendString(o, za0002[za0003])
}
}
// string "b"
o = append(o, 0xa1, 0x62)
o = msgp.AppendBytes(o, z.Body)
// string "sc"
o = append(o, 0xa2, 0x73, 0x63)
o = msgp.AppendInt(o, z.StatusCode)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *response) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "hs":
var zb0002 uint32
zb0002, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Headers")
return
}
if zb0002 > 1024 {
err = msgp.ErrLimitExceeded
return
}
if z.Headers == nil {
z.Headers = make(map[string][]string, zb0002)
} else if len(z.Headers) > 0 {
clear(z.Headers)
}
for zb0002 > 0 {
var za0002 []string
zb0002--
var za0001 string
za0001, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Headers")
return
}
var zb0003 uint32
zb0003, bts, err = msgp.ReadArrayHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Headers", za0001)
return
}
if zb0003 > 1024 {
err = msgp.ErrLimitExceeded
return
}
if cap(za0002) >= int(zb0003) {
za0002 = (za0002)[:zb0003]
} else {
za0002 = make([]string, zb0003)
}
for za0003 := range za0002 {
za0002[za0003], bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "Headers", za0001, za0003)
return
}
}
z.Headers[za0001] = za0002
}
case "b":
z.Body, bts, err = msgp.ReadBytesBytes(bts, z.Body)
if err != nil {
err = msgp.WrapError(err, "Body")
return
}
case "sc":
z.StatusCode, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "StatusCode")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *response) Msgsize() (s int) {
s = 1 + 3 + msgp.MapHeaderSize
if z.Headers != nil {
for za0001, za0002 := range z.Headers {
_ = za0002
s += msgp.StringPrefixSize + len(za0001) + msgp.ArrayHeaderSize
for za0003 := range za0002 {
s += msgp.StringPrefixSize + len(za0002[za0003])
}
}
}
s += 2 + msgp.BytesPrefixSize + len(z.Body) + 3 + msgp.IntSize
return
}
================================================
FILE: middleware/idempotency/response_msgp_test.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package idempotency
import (
"bytes"
"testing"
"github.com/tinylib/msgp/msgp"
)
func TestMarshalUnmarshalresponse(t *testing.T) {
v := response{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
func BenchmarkMarshalMsgresponse(b *testing.B) {
v := response{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
func BenchmarkAppendMsgresponse(b *testing.B) {
v := response{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
func BenchmarkUnmarshalresponse(b *testing.B) {
v := response{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
func TestEncodeDecoderesponse(t *testing.T) {
v := response{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecoderesponse Msgsize() is inaccurate")
}
vn := response{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
func BenchmarkEncoderesponse(b *testing.B) {
v := response{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
func BenchmarkDecoderesponse(b *testing.B) {
v := response{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}
================================================
FILE: middleware/idempotency/stub_test.go
================================================
package idempotency
import (
"context"
"time"
)
// stubLock implements Locker for testing purposes.
type stubLock struct {
lockErr error
unlockErr error
afterLock func()
}
func (s *stubLock) Lock(string) error {
if s.afterLock != nil {
s.afterLock()
}
return s.lockErr
}
func (s *stubLock) Unlock(string) error { return s.unlockErr }
// stubStorage implements fiber.Storage for testing.
type stubStorage struct {
data map[string][]byte
getErr error
setErr error
setCount int
}
func (s *stubStorage) Get(key string) ([]byte, error) {
if s.getErr != nil {
return nil, s.getErr
}
if s.data == nil {
return nil, nil
}
return s.data[key], nil
}
func (s *stubStorage) GetWithContext(_ context.Context, key string) ([]byte, error) {
// Call Get method to avoid code duplication
return s.Get(key)
}
func (s *stubStorage) Set(key string, val []byte, _ time.Duration) error {
if s.setErr != nil {
return s.setErr
}
if s.data == nil {
s.data = make(map[string][]byte)
}
s.data[key] = val
s.setCount++
return nil
}
func (s *stubStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error {
// Call Set method to avoid code duplication
return s.Set(key, val, 0)
}
func (s *stubStorage) Delete(key string) error {
if s.data != nil {
delete(s.data, key)
}
return nil
}
func (s *stubStorage) DeleteWithContext(_ context.Context, key string) error {
// Call Delete method to avoid code duplication
return s.Delete(key)
}
func (s *stubStorage) Reset() error {
s.data = make(map[string][]byte)
return nil
}
func (s *stubStorage) ResetWithContext(_ context.Context) error {
// Call Reset method to avoid code duplication
return s.Reset()
}
func (*stubStorage) Close() error { return nil }
================================================
FILE: middleware/keyauth/config.go
================================================
package keyauth
import (
"fmt"
"net/url"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
)
const (
ErrorInvalidRequest = "invalid_request"
ErrorInvalidToken = "invalid_token"
ErrorInsufficientScope = "insufficient_scope"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c fiber.Ctx) bool
// SuccessHandler defines a function which is executed for a valid key.
//
// Optional. Default: c.Next()
SuccessHandler fiber.Handler
// ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error.
//
// Optional. Default: 401 Missing or invalid API Key
ErrorHandler fiber.ErrorHandler
// Validator is a function to validate the key.
//
// Required.
Validator func(c fiber.Ctx, key string) (bool, error)
// Realm defines the protected area for WWW-Authenticate responses.
// This is used to set the `WWW-Authenticate` header when authentication fails.
//
// Optional. Default value "Restricted".
Realm string
// Challenge defines the full `WWW-Authenticate` header value used when
// the middleware responds with 401 and no Authorization scheme is
// present.
//
// Optional. Default: `ApiKey realm=""` when no Authorization scheme
// is configured.
Challenge string
// Error is the RFC 6750 `error` parameter appended to Bearer
// `WWW-Authenticate` challenges when validation fails. Allowed values
// are `invalid_request`, `invalid_token`, or `insufficient_scope`.
//
// Optional. Default: "".
Error string
// ErrorDescription is the RFC 6750 `error_description` parameter
// appended to Bearer `WWW-Authenticate` challenges when validation
// fails. This field requires that `Error` is also set.
//
// Optional. Default: "".
ErrorDescription string
// ErrorURI is the RFC 6750 `error_uri` parameter appended to Bearer
// `WWW-Authenticate` challenges when validation fails. This field
// requires that `Error` is also set.
//
// Optional. Default: "".
ErrorURI string
// Scope is the RFC 6750 `scope` parameter appended to Bearer
// challenges when the `error` is `insufficient_scope`. This field
// requires that `Error` is set to `insufficient_scope`.
//
// Optional. Default: "".
Scope string
// Extractor is a function to extract the key from the request.
//
// Optional. Default: extractors.FromAuthHeader("Bearer")
Extractor extractors.Extractor
}
// ConfigDefault is the default config
var ConfigDefault = Config{
SuccessHandler: func(c fiber.Ctx) error {
return c.Next()
},
ErrorHandler: func(c fiber.Ctx, _ error) error {
return c.Status(fiber.StatusUnauthorized).SendString(ErrMissingOrMalformedAPIKey.Error())
},
Realm: "Restricted",
Extractor: extractors.FromAuthHeader("Bearer"),
}
// configDefault is a helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
panic("fiber: keyauth middleware requires a validator function")
}
cfg := config[0]
// Require a validator function
if cfg.Validator == nil {
panic("fiber: keyauth middleware requires a validator function")
}
// Set default values
if cfg.Extractor.Extract == nil {
cfg.Extractor = ConfigDefault.Extractor
}
if cfg.Realm == "" {
cfg.Realm = ConfigDefault.Realm
}
if cfg.SuccessHandler == nil {
cfg.SuccessHandler = ConfigDefault.SuccessHandler
}
if cfg.ErrorHandler == nil {
cfg.ErrorHandler = ConfigDefault.ErrorHandler
}
if len(getAuthSchemes(cfg.Extractor)) == 0 && cfg.Challenge == "" {
cfg.Challenge = fmt.Sprintf("ApiKey realm=%q", cfg.Realm)
}
if cfg.Error != "" {
switch cfg.Error {
case ErrorInvalidRequest, ErrorInvalidToken, ErrorInsufficientScope:
default:
panic("fiber: keyauth unsupported error token")
}
}
if cfg.ErrorDescription != "" && cfg.Error == "" {
panic("fiber: keyauth error_description requires error")
}
if cfg.ErrorURI != "" {
if cfg.Error == "" {
panic("fiber: keyauth error_uri requires error")
}
if u, err := url.Parse(cfg.ErrorURI); err != nil || !u.IsAbs() {
panic("fiber: keyauth error_uri must be absolute")
}
}
if cfg.Error == ErrorInsufficientScope {
if cfg.Scope == "" {
panic("fiber: keyauth insufficient_scope requires scope")
}
for scope := range strings.SplitSeq(cfg.Scope, " ") {
if scope == "" || !isScopeToken(scope) {
panic("fiber: keyauth scope contains invalid token")
}
}
} else if cfg.Scope != "" {
panic("fiber: keyauth scope requires insufficient_scope error")
}
return cfg
}
func isScopeToken(s string) bool {
for i := 0; i < len(s); i++ {
c := s[i]
if c < 0x21 || c > 0x7e || c == '"' || c == '\\' {
return false
}
}
return s != ""
}
================================================
FILE: middleware/keyauth/config_test.go
================================================
package keyauth
import (
"reflect"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_KeyAuth_ConfigDefault_NoConfig tests the case where no config is provided.
func Test_KeyAuth_ConfigDefault_NoConfig(t *testing.T) {
t.Parallel()
// The New function will call configDefault with no arguments
// which will panic because ConfigDefault.Validator is nil.
assert.PanicsWithValue(t, "fiber: keyauth middleware requires a validator function", func() {
New()
}, "Calling New() without a validator should panic")
}
// Test_KeyAuth_ConfigDefault_PanicWithoutValidator tests that configDefault panics when Validator is nil.
func Test_KeyAuth_ConfigDefault_PanicWithoutValidator(t *testing.T) {
t.Parallel()
assert.PanicsWithValue(t, "fiber: keyauth middleware requires a validator function", func() {
configDefault(Config{})
}, "configDefault should panic if validator is not provided")
}
// Test_KeyAuth_ConfigDefault_WithValidator verifies that configDefault fills
// in the realm, handlers and extractor when only a Validator is supplied.
func Test_KeyAuth_ConfigDefault_WithValidator(t *testing.T) {
	t.Parallel()

	alwaysValid := func(fiber.Ctx, string) (bool, error) { return true, nil }
	got := configDefault(Config{Validator: alwaysValid})

	require.NotNil(t, got.Validator)
	assert.Equal(t, ConfigDefault.Realm, got.Realm)
	require.NotNil(t, got.SuccessHandler)
	require.NotNil(t, got.ErrorHandler)
	require.NotNil(t, got.Extractor.Extract)
}
// Test_KeyAuth_ConfigDefault_CustomConfig verifies that configDefault keeps
// every user-supplied field instead of overwriting it with a default.
func Test_KeyAuth_ConfigDefault_CustomConfig(t *testing.T) {
	t.Parallel()

	skipAll := func(_ fiber.Ctx) bool { return true }
	onSuccess := func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }
	onError := func(c fiber.Ctx, _ error) error { return c.SendStatus(fiber.StatusForbidden) }
	acceptAll := func(_ fiber.Ctx, _ string) (bool, error) { return true, nil }
	headerExtractor := extractors.FromHeader("X-API-Key")

	got := configDefault(Config{
		Next:           skipAll,
		SuccessHandler: onSuccess,
		ErrorHandler:   onError,
		Validator:      acceptAll,
		Realm:          "API",
		Extractor:      headerExtractor,
	})

	// Go funcs are not comparable with ==, so compare their code pointers.
	funcPtr := func(fn any) uintptr { return reflect.ValueOf(fn).Pointer() }

	assert.Equal(t, funcPtr(skipAll), funcPtr(got.Next))
	assert.Equal(t, funcPtr(onSuccess), funcPtr(got.SuccessHandler))
	assert.Equal(t, funcPtr(onError), funcPtr(got.ErrorHandler))
	assert.Equal(t, funcPtr(acceptAll), funcPtr(got.Validator))
	assert.Equal(t, funcPtr(headerExtractor.Extract), funcPtr(got.Extractor.Extract))
	assert.Equal(t, "API", got.Realm)
}
================================================
FILE: middleware/keyauth/keyauth.go
================================================
package keyauth
import (
"errors"
"fmt"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/utils/v2"
)
// contextKey is unexported so that keys this package stores in a request
// context cannot collide with keys defined by other packages.
type contextKey int

// tokenKey is the context key under which the validated API key is stored.
const tokenKey contextKey = 0

// ErrMissingOrMalformedAPIKey is returned when the API key is missing or invalid.
var ErrMissingOrMalformedAPIKey = errors.New("missing or invalid API Key")
// New creates a new keyauth middleware handler.
//
// For every request not skipped by cfg.Next, the handler extracts a key with
// cfg.Extractor and passes it to cfg.Validator. A valid key is stored in the
// request context (readable with TokenFromContext) and cfg.SuccessHandler runs.
// Otherwise cfg.ErrorHandler runs with a non-nil error:
//   - ErrMissingOrMalformedAPIKey when no key was found,
//   - any other error returned by the extractor or the validator,
//   - ErrMissingOrMalformedAPIKey when the validator rejected the key without
//     an error.
//
// If the error handler leaves the response with status 401 or 407, a
// WWW-Authenticate or Proxy-Authenticate challenge header is added.
//
// New panics if the configuration is invalid (see configDefault).
func New(config ...Config) fiber.Handler {
	// Init config
	cfg := configDefault(config...)

	// The challenge depends only on the resolved config and the extractor
	// chain, both fixed from here on, so build it once instead of on every
	// rejected request.
	challenge := cfg.Challenge
	if authSchemes := getAuthSchemes(cfg.Extractor); len(authSchemes) > 0 {
		// One challenge per Authorization-header scheme found in the chain.
		challenges := make([]string, 0, len(authSchemes))
		for _, scheme := range authSchemes {
			var b strings.Builder
			fmt.Fprintf(&b, "%s realm=%q", scheme, cfg.Realm)
			// error, error_description, error_uri and scope are Bearer
			// (RFC 6750) parameters, so only that scheme carries them.
			if utils.EqualFold(scheme, "Bearer") && cfg.Error != "" {
				fmt.Fprintf(&b, ", error=%q", cfg.Error)
				if cfg.ErrorDescription != "" {
					fmt.Fprintf(&b, ", error_description=%q", cfg.ErrorDescription)
				}
				if cfg.ErrorURI != "" {
					fmt.Fprintf(&b, ", error_uri=%q", cfg.ErrorURI)
				}
				if cfg.Error == ErrorInsufficientScope {
					fmt.Fprintf(&b, ", scope=%q", cfg.Scope)
				}
			}
			challenges = append(challenges, b.String())
		}
		challenge = strings.Join(challenges, ", ")
	}

	// Return middleware handler
	return func(c fiber.Ctx) error {
		// Filter request to skip middleware
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Extract and verify key
		key, err := cfg.Extractor.Extract(c)
		if errors.Is(err, extractors.ErrNotFound) {
			// Replace shared extractor not found error with a keyauth specific error
			err = ErrMissingOrMalformedAPIKey
		}

		// If there was no error extracting the key, validate it
		if err == nil {
			var valid bool
			valid, err = cfg.Validator(c, key)
			if err == nil && valid {
				fiber.StoreInContext(c, tokenKey, key)
				return cfg.SuccessHandler(c)
			}
			if err == nil {
				// The validator rejected the key without a reason; never hand
				// the error handler a nil error for a failed authentication.
				err = ErrMissingOrMalformedAPIKey
			}
		}

		// Execute the error handler first so it decides the status code.
		handlerErr := cfg.ErrorHandler(c, err)

		if challenge != "" {
			switch c.Response().StatusCode() {
			case fiber.StatusUnauthorized:
				c.Set(fiber.HeaderWWWAuthenticate, challenge)
			case fiber.StatusProxyAuthRequired:
				c.Set(fiber.HeaderProxyAuthenticate, challenge)
			}
		}
		return handlerErr
	}
}
// TokenFromContext returns the API key that the keyauth middleware stored for
// the current request, whichever extractor found it.
// It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context.
// The empty string is returned when no key has been stored.
func TokenFromContext(ctx any) string {
	token, found := fiber.ValueFromContext[string](ctx, tokenKey)
	if !found {
		return ""
	}
	return token
}
// getAuthSchemes walks an extractor and, depth-first in chain order, every
// extractor nested in its Chain. It collects the AuthScheme of each one that
// reads the Authorization header with a non-empty scheme. The result is nil
// when no such extractor exists.
func getAuthSchemes(e extractors.Extractor) []string {
	var found []string

	var walk func(node extractors.Extractor)
	walk = func(node extractors.Extractor) {
		if node.Source == extractors.SourceAuthHeader && node.AuthScheme != "" {
			found = append(found, node.AuthScheme)
		}
		for i := range node.Chain {
			walk(node.Chain[i])
		}
	}

	walk(e)
	return found
}
================================================
FILE: middleware/keyauth/keyauth_test.go
================================================
package keyauth
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
)
// CorrectKey is the API key accepted by most validators in this file. It mixes
// in '.', '/', '~' and '+' — presumably to exercise token68-style keys.
const CorrectKey = "correct-token_123./~+"

// testConfig is passed to app.Test by several tests and sets Timeout to 0.
// NOTE(review): presumably this disables app.Test's response timeout — confirm
// against the fiber.TestConfig documentation.
var testConfig = fiber.TestConfig{Timeout: 0}

// Extractor source names iterated by Test_AuthSources.
const (
	headerExtractorName     = "header"
	authHeaderExtractorName = "authHeader"
	cookieExtractorName     = "cookie"
	queryExtractorName      = "query"
	paramExtractorName      = "param"
	formExtractorName       = "form"
)
// Test_AuthSources runs three scenarios (correct key, no key, wrong key)
// against every extractor source: header, Authorization header, cookie, query,
// route param and form field. Each case gets its own fiber app.
func Test_AuthSources(t *testing.T) {
	// define test cases
	testSources := []string{headerExtractorName, authHeaderExtractorName, cookieExtractorName, queryExtractorName, paramExtractorName, formExtractorName}
	// NOTE(review): route is always "/" and is never read below; the route that
	// is registered and requested is derived from the extractor source instead.
	tests := []struct {
		route         string
		authTokenName string
		description   string
		APIKey        string
		expectedBody  string
		expectedCode  int
	}{
		{
			route:         "/",
			authTokenName: "access_token",
			description:   "auth with correct key",
			APIKey:        CorrectKey,
			expectedCode:  200,
			expectedBody:  "Success!",
		},
		{
			route:         "/",
			authTokenName: "access_token",
			description:   "auth with no key",
			APIKey:        "",
			expectedCode:  401, // 404 in case of param authentication
			expectedBody:  ErrMissingOrMalformedAPIKey.Error(),
		},
		{
			route:         "/",
			authTokenName: "access_token",
			description:   "auth with wrong key",
			APIKey:        "WRONGKEY",
			expectedCode:  401,
			expectedBody:  ErrMissingOrMalformedAPIKey.Error(),
		},
	}
	for _, authSource := range testSources {
		t.Run(authSource, func(t *testing.T) {
			for _, test := range tests {
				// NOTE(review): UnescapePath is presumably needed so the
				// url.PathEscape'd param key below is decoded again — confirm.
				app := fiber.New(fiber.Config{UnescapePath: true})
				testKey := test.APIKey
				correctKey := CorrectKey
				// Use a simple key for param and cookie to avoid encoding issues in the test setup
				if authSource == paramExtractorName || authSource == cookieExtractorName {
					if test.APIKey != "" && test.APIKey != "WRONGKEY" {
						testKey = "simple-key"
						correctKey = "simple-key"
					}
				}
				// Build the middleware with the extractor matching this source.
				authMiddleware := New(Config{
					Extractor: func() extractors.Extractor {
						switch authSource {
						case headerExtractorName:
							return extractors.FromHeader(test.authTokenName)
						case authHeaderExtractorName:
							return extractors.FromAuthHeader("Bearer")
						case cookieExtractorName:
							return extractors.FromCookie(test.authTokenName)
						case queryExtractorName:
							return extractors.FromQuery(test.authTokenName)
						case paramExtractorName:
							return extractors.FromParam(test.authTokenName)
						case formExtractorName:
							return extractors.FromForm(test.authTokenName)
						default:
							panic("unknown source")
						}
					}(),
					Validator: func(_ fiber.Ctx, key string) (bool, error) {
						if key == correctKey {
							return true, nil
						}
						return false, errors.New("invalid key")
					},
				})
				handler := func(c fiber.Ctx) error {
					return c.SendString("Success!")
				}
				// Param keys live in the path and form keys need a POST body;
				// every other source is served by GET "/".
				method := fiber.MethodGet
				switch authSource {
				case paramExtractorName:
					app.Get("/:"+test.authTokenName, authMiddleware, handler)
				case formExtractorName:
					method = fiber.MethodPost
					app.Post("/", authMiddleware, handler)
				default:
					app.Get("/", authMiddleware, handler)
				}
				targetURL := "/"
				if authSource == paramExtractorName {
					targetURL = "/" + url.PathEscape(testKey)
				}
				var reqBody io.Reader
				if authSource == formExtractorName {
					form := url.Values{}
					form.Add(test.authTokenName, testKey)
					bodyStr := form.Encode()
					reqBody = strings.NewReader(bodyStr)
				}
				req, err := http.NewRequestWithContext(context.Background(), method, targetURL, reqBody)
				require.NoError(t, err)
				// Place the key where this source's extractor looks for it.
				switch authSource {
				case headerExtractorName:
					req.Header.Set(test.authTokenName, testKey)
				case authHeaderExtractorName:
					if testKey != "" {
						req.Header.Set("Authorization", "Bearer "+testKey)
					}
				case cookieExtractorName:
					req.Header.Set("Cookie", test.authTokenName+"="+testKey)
				case queryExtractorName:
					q := req.URL.Query()
					q.Add(test.authTokenName, testKey)
					req.URL.RawQuery = q.Encode()
				case formExtractorName:
					req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
				default:
					// nothing to do for paramExtractorName
				}
				res, err := app.Test(req, testConfig)
				require.NoError(t, err, test.description)
				body, err := io.ReadAll(res.Body)
				require.NoError(t, err)
				errClose := res.Body.Close()
				require.NoError(t, errClose)
				expectedCode := test.expectedCode
				expectedBody := test.expectedBody
				// NOTE(review): redundant — the table already sets this body
				// for the empty and "WRONGKEY" cases.
				if test.APIKey == "" || test.APIKey == "WRONGKEY" {
					expectedBody = ErrMissingOrMalformedAPIKey.Error()
				}
				// An empty param key yields the path "/", which does not match
				// "/:access_token", so the router answers 404 before keyauth runs.
				if authSource == paramExtractorName && testKey == "" {
					expectedCode = 404
					expectedBody = "Not Found"
				}
				require.Equal(t, expectedCode, res.StatusCode, test.description)
				require.Equal(t, expectedBody, string(body), test.description)
			}
		})
	}
}
// TestMultipleKeyLookup verifies that a chained extractor finds the key in a
// later source (here the query string) when earlier sources are empty, and
// that a request carrying no key in any source is rejected with 401.
func TestMultipleKeyLookup(t *testing.T) {
	const (
		desc    = "auth with correct key"
		success = "Success!"
		scheme  = "Bearer"
	)

	// set up the fiber endpoint
	app := fiber.New()
	customExtractor := extractors.Chain(
		extractors.FromAuthHeader(scheme),
		extractors.FromHeader("key"),
		extractors.FromCookie("key"),
		extractors.FromQuery("key"),
	)
	authMiddleware := New(Config{
		Extractor: customExtractor,
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, errors.New("invalid key")
		},
	})
	app.Use(authMiddleware)
	app.Get("/foo", func(c fiber.Ctx) error {
		return c.SendString(success)
	})

	// First request: the key is supplied only via the query string.
	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", http.NoBody)
	require.NoError(t, err)
	q := req.URL.Query()
	q.Add("key", CorrectKey)
	req.URL.RawQuery = q.Encode()

	res, err := app.Test(req, testConfig)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, fiber.StatusOK, res.StatusCode, desc)
	require.Equal(t, success, string(body), desc)

	// Second request: no key anywhere, so the middleware must reject it.
	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", http.NoBody)
	require.NoError(t, err)
	res, err = app.Test(req, testConfig)
	require.NoError(t, err)
	errBody, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.NoError(t, res.Body.Close())
	require.Equal(t, fiber.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(errBody))
}
// Test_MultipleKeyAuth mounts two keyauth middlewares, each limited to one
// path through Next, and checks that each path accepts only its own key while
// "/" stays public.
func Test_MultipleKeyAuth(t *testing.T) {
	// set up the fiber endpoint
	app := fiber.New()

	// set up keyauth for /auth1
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Path() != "/auth1"
		},
		Extractor: extractors.FromAuthHeader("Bearer"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == "password1" {
				return true, nil
			}
			return false, errors.New("invalid key")
		},
	}))

	// set up keyauth for /auth2
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Path() != "/auth2"
		},
		Extractor: extractors.FromAuthHeader("Bearer"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == "password2" {
				return true, nil
			}
			return false, errors.New("invalid key")
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("No auth needed!")
	})
	app.Get("/auth1", func(c fiber.Ctx) error {
		return c.SendString("Successfully authenticated for auth1!")
	})
	app.Get("/auth2", func(c fiber.Ctx) error {
		return c.SendString("Successfully authenticated for auth2!")
	})

	// define test cases
	tests := []struct {
		route        string
		description  string
		APIKey       string
		expectedBody string
		expectedCode int
	}{
		// No auth needed for /
		{
			route:        "/",
			description:  "No password needed",
			APIKey:       "",
			expectedCode: 200,
			expectedBody: "No auth needed!",
		},
		// auth needed for auth1
		{
			route:        "/auth1",
			description:  "Normal Authentication Case",
			APIKey:       "password1",
			expectedCode: 200,
			expectedBody: "Successfully authenticated for auth1!",
		},
		{
			route:        "/auth1",
			description:  "Wrong API Key for auth1",
			APIKey:       "WRONG KEY",
			expectedCode: 401,
			expectedBody: ErrMissingOrMalformedAPIKey.Error(),
		},
		{
			route:        "/auth1",
			description:  "Missing API Key for auth1",
			APIKey:       "",
			expectedCode: 401,
			expectedBody: ErrMissingOrMalformedAPIKey.Error(),
		},
		// Auth 2 has a different password
		{
			route:        "/auth2",
			description:  "Normal Authentication Case for auth2",
			APIKey:       "password2",
			expectedCode: 200,
			expectedBody: "Successfully authenticated for auth2!",
		},
		{
			route:        "/auth2",
			description:  "Wrong API Key for auth2",
			APIKey:       "WRONG KEY",
			expectedCode: 401,
			expectedBody: ErrMissingOrMalformedAPIKey.Error(),
		},
		{
			route:        "/auth2",
			description:  "Missing API Key for auth2",
			APIKey:       "",
			expectedCode: 401,
			expectedBody: ErrMissingOrMalformedAPIKey.Error(),
		},
	}

	// run the tests
	for _, test := range tests {
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, http.NoBody)
		require.NoError(t, err, test.description)
		if test.APIKey != "" {
			req.Header.Set("Authorization", "Bearer "+test.APIKey)
		}

		res, err := app.Test(req, testConfig)
		require.NoError(t, err, test.description)

		body, err := io.ReadAll(res.Body)
		require.NoError(t, err, test.description)
		require.NoError(t, res.Body.Close(), test.description)

		require.Equal(t, test.expectedCode, res.StatusCode, test.description)
		require.Equal(t, test.expectedBody, string(body), test.description)
	}
}
// Test_CustomSuccessAndFailureHandlers verifies that user-supplied
// SuccessHandler and ErrorHandler replace the defaults. The route handler must
// never run, because the custom SuccessHandler answers without calling Next.
func Test_CustomSuccessAndFailureHandlers(t *testing.T) {
	const (
		invalidMsg = "API key is invalid and request was handled by custom error handler"
		validMsg   = "API key is valid and request was handled by custom success handler"
	)

	app := fiber.New()
	app.Use(New(Config{
		SuccessHandler: func(c fiber.Ctx) error {
			return c.Status(fiber.StatusOK).SendString(validMsg)
		},
		ErrorHandler: func(c fiber.Ctx, _ error) error {
			return c.Status(fiber.StatusUnauthorized).SendString(invalidMsg)
		},
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key != CorrectKey {
				return false, ErrMissingOrMalformedAPIKey
			}
			return true, nil
		},
	}))

	// Must stay unreachable: the custom SuccessHandler responds itself.
	app.Get("/", func(_ fiber.Ctx) error {
		t.Error("Test handler should not be called")
		return nil
	})

	// send performs req and returns its status code and body.
	send := func(req *http.Request) (int, string) {
		res, err := app.Test(req)
		require.NoError(t, err)
		body, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		return res.StatusCode, string(body)
	}

	// Without a key, the custom error handler answers.
	status, body := send(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.Equal(t, http.StatusUnauthorized, status)
	require.Equal(t, invalidMsg, body)

	// With a valid Bearer key, the custom success handler answers.
	authed := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	authed.Header.Add("Authorization", "Bearer "+CorrectKey)
	status, body = send(authed)
	require.Equal(t, http.StatusOK, status)
	require.Equal(t, validMsg, body)
}
// Test_CustomNextFunc verifies that a Next filter lets matching paths bypass
// authentication while other paths still require a valid key.
func Test_CustomNextFunc(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Path() == "/allowed"
		},
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key != CorrectKey {
				return false, ErrMissingOrMalformedAPIKey
			}
			return true, nil
		},
	}))

	app.Get("/allowed", func(c fiber.Ctx) error {
		return c.SendString("API key is valid and request was allowed by custom filter")
	})
	app.Get("/not-allowed", func(c fiber.Ctx) error {
		return c.SendString("Should be protected")
	})

	// fetch issues a GET to path, adding an Authorization header when one is
	// given, and returns the status code and body.
	fetch := func(path, authorization string) (int, string) {
		req := httptest.NewRequest(fiber.MethodGet, path, http.NoBody)
		if authorization != "" {
			req.Header.Add("Authorization", authorization)
		}
		res, err := app.Test(req)
		require.NoError(t, err)
		body, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		return res.StatusCode, string(body)
	}

	// Next skips authentication for /allowed, so no key is needed there.
	status, body := fetch("/allowed", "")
	require.Equal(t, http.StatusOK, status)
	require.Equal(t, "API key is valid and request was allowed by custom filter", body)

	// Other paths are protected: without a key the request is rejected...
	status, body = fetch("/not-allowed", "")
	require.Equal(t, http.StatusUnauthorized, status)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), body)

	// ...and with the correct key it reaches the handler.
	status, body = fetch("/not-allowed", "Bearer "+CorrectKey)
	require.Equal(t, http.StatusOK, status)
	require.Equal(t, "Should be protected", body)
}
// Test_TokenFromContext_None verifies that TokenFromContext yields an empty
// string when no keyauth middleware has stored a token for the request.
func Test_TokenFromContext_None(t *testing.T) {
	app := fiber.New()

	// Echo whatever TokenFromContext returns; no keyauth middleware is installed.
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(TokenFromContext(c))
	})

	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)

	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Empty(t, body)
}
// Test_TokenFromContext verifies that after successful authentication the
// route handler can read the key back through TokenFromContext.
func Test_TokenFromContext(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Extractor: extractors.FromAuthHeader("Basic"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key != CorrectKey {
				return false, ErrMissingOrMalformedAPIKey
			}
			return true, nil
		},
	}))

	// Echo the stored token so the test can inspect it.
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(TokenFromContext(c))
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Basic "+CorrectKey)

	res, err := app.Test(req)
	require.NoError(t, err)

	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, CorrectKey, string(body))
}
// Test_TokenFromContext_Types verifies that TokenFromContext finds the stored
// key through every supported context type: fiber.Ctx, fiber.CustomCtx, the
// underlying *fasthttp.RequestCtx, and context.Context (with
// PassLocalsToContext enabled).
func Test_TokenFromContext_Types(t *testing.T) {
	t.Parallel()

	app := fiber.New(fiber.Config{PassLocalsToContext: true})
	app.Use(New(Config{
		Extractor: extractors.FromAuthHeader("Basic"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	// The handler runs on the goroutine app.Test uses to serve the request, not
	// on the test goroutine. require.* calls t.FailNow, which is only valid on
	// the test goroutine, so use assert.* here; it only records failures.
	app.Get("/", func(c fiber.Ctx) error {
		assert.Equal(t, CorrectKey, TokenFromContext(c))
		customCtx, ok := c.(fiber.CustomCtx)
		if assert.True(t, ok) {
			assert.Equal(t, CorrectKey, TokenFromContext(customCtx))
		}
		assert.Equal(t, CorrectKey, TokenFromContext(c.RequestCtx()))
		assert.Equal(t, CorrectKey, TokenFromContext(c.Context()))
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Basic "+CorrectKey)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
// Test_AuthSchemeToken verifies that FromAuthHeader works with a custom
// scheme name ("Token") in the Authorization header.
func Test_AuthSchemeToken(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Extractor: extractors.FromAuthHeader("Token"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key != CorrectKey {
				return false, ErrMissingOrMalformedAPIKey
			}
			return true, nil
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("API key is valid")
	})

	// Send a valid key under the "Token" scheme.
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Token "+CorrectKey)

	res, err := app.Test(req)
	require.NoError(t, err)

	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "API key is valid", string(body))
}
// Test_AuthSchemeBasic verifies the "Basic" scheme end to end: a request
// without a key is rejected, and a request with the correct key is accepted.
func Test_AuthSchemeBasic(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Extractor: extractors.FromAuthHeader("Basic"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key != CorrectKey {
				return false, ErrMissingOrMalformedAPIKey
			}
			return true, nil
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("API key is valid")
	})

	// roundTrip sends req and returns the status code and body.
	roundTrip := func(req *http.Request) (int, string) {
		res, err := app.Test(req)
		require.NoError(t, err)
		body, err := io.ReadAll(res.Body)
		require.NoError(t, err)
		return res.StatusCode, string(body)
	}

	// No key: rejected.
	status, body := roundTrip(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.Equal(t, http.StatusUnauthorized, status)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), body)

	// Correct key under the "Basic" scheme: accepted.
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Basic "+CorrectKey)
	status, body = roundTrip(req)
	require.Equal(t, http.StatusOK, status)
	require.Equal(t, "API key is valid", body)
}
// Test_HeaderSchemeCaseInsensitive verifies that the default extractor
// matches the Bearer scheme case-insensitively ("bearer <key>" is accepted).
func Test_HeaderSchemeCaseInsensitive(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key != CorrectKey {
				return false, ErrMissingOrMalformedAPIKey
			}
			return true, nil
		},
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "bearer "+CorrectKey)

	res, err := app.Test(req)
	require.NoError(t, err)

	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "OK", string(body))
}
// Test_DefaultErrorHandlerChallenge verifies that a rejected request under the
// default config gets a 401 with the default Bearer challenge.
func Test_DefaultErrorHandlerChallenge(t *testing.T) {
	denyAll := func(_ fiber.Ctx, _ string) (bool, error) {
		return false, ErrMissingOrMalformedAPIKey
	}

	app := fiber.New()
	app.Use(New(Config{Validator: denyAll}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	// No Authorization header at all.
	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, "Bearer realm=\"Restricted\"", res.Header.Get("WWW-Authenticate"))
}
// Test_DefaultErrorHandlerInvalid verifies that when the validator fails with
// an arbitrary error, the default error handler still answers 401 with
// ErrMissingOrMalformedAPIKey's message and the default Bearer challenge.
func Test_DefaultErrorHandlerInvalid(t *testing.T) {
	failing := func(_ fiber.Ctx, _ string) (bool, error) {
		return false, errors.New("invalid")
	}

	app := fiber.New()
	app.Use(New(Config{Validator: failing}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer "+CorrectKey)

	res, err := app.Test(req)
	require.NoError(t, err)

	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
	require.Equal(t, "Bearer realm=\"Restricted\"", res.Header.Get("WWW-Authenticate"))
}
// Test_HeaderSchemeMultipleSpaces verifies that an Authorization header with
// more than one space between the scheme and the key is rejected, even though
// the key itself is correct. With exactly one space the same key is accepted
// (see Test_HeaderSchemeCaseInsensitive), so the header must carry several.
func Test_HeaderSchemeMultipleSpaces(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	// Several spaces between "Bearer" and the key.
	req.Header.Add("Authorization", "Bearer    "+CorrectKey)

	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}
// Test_HeaderSchemeMissingSpace verifies that a key glued directly to the
// scheme ("Bearer<key>") is rejected by the extractor. The validator accepts
// CorrectKey on purpose: if the extractor wrongly stripped the scheme and
// returned the key, the request would succeed and this test would fail.
func Test_HeaderSchemeMissingSpace(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{Validator: func(_ fiber.Ctx, key string) (bool, error) {
		if key == CorrectKey {
			return true, nil
		}
		return false, ErrMissingOrMalformedAPIKey
	}}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer"+CorrectKey)

	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}
// Test_HeaderSchemeNoToken verifies that "Bearer " with nothing after the
// separator is rejected with 401.
func Test_HeaderSchemeNoToken(t *testing.T) {
	denyAll := func(_ fiber.Ctx, _ string) (bool, error) {
		return false, ErrMissingOrMalformedAPIKey
	}

	app := fiber.New()
	app.Use(New(Config{Validator: denyAll}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer ")

	res, err := app.Test(req)
	require.NoError(t, err)

	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}
// Test_HeaderSchemeNoSeparator verifies that an Authorization value with no
// separator after the scheme ("BearerTokenWithoutSpace") is rejected with 401.
func Test_HeaderSchemeNoSeparator(t *testing.T) {
	denyAll := func(_ fiber.Ctx, _ string) (bool, error) {
		return false, ErrMissingOrMalformedAPIKey
	}

	app := fiber.New()
	app.Use(New(Config{Validator: denyAll}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	// The scheme runs straight into the token.
	req.Header.Add("Authorization", "BearerTokenWithoutSpace")

	res, err := app.Test(req)
	require.NoError(t, err)

	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}
// Test_HeaderSchemeEmptyTokenAfterTrim verifies that a scheme followed only by
// spaces and tabs, with no actual token, is rejected with 401.
func Test_HeaderSchemeEmptyTokenAfterTrim(t *testing.T) {
	denyAll := func(_ fiber.Ctx, _ string) (bool, error) {
		return false, ErrMissingOrMalformedAPIKey
	}

	app := fiber.New()
	app.Use(New(Config{Validator: denyAll}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	// Only whitespace after the scheme.
	req.Header.Add("Authorization", "Bearer \t \t ")

	res, err := app.Test(req)
	require.NoError(t, err)

	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}
func Test_WWWAuthenticateHeader(t *testing.T) {
t.Parallel()
tests := []struct {
name string
expectedHeader string
config Config
expectedStatusCode int
}{
{
name: "default config on failure",
config: Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) {
return false, errors.New("validation failed")
},
},
expectedHeader: `Bearer realm="Restricted"`,
expectedStatusCode: fiber.StatusUnauthorized,
},
{
name: "custom realm on failure",
config: Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) {
return false, errors.New("validation failed")
},
Realm: "My Custom Realm",
},
expectedHeader: `Bearer realm="My Custom Realm"`,
expectedStatusCode: fiber.StatusUnauthorized,
},
{
name: "default header for non-auth-header extractor",
config: Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) {
return false, errors.New("validation failed")
},
Extractor: extractors.FromQuery("api_key"),
},
expectedHeader: `ApiKey realm="Restricted"`,
expectedStatusCode: fiber.StatusUnauthorized,
},
{
name: "no header on success",
config: Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) {
return true, nil
},
},
expectedHeader: "",
expectedStatusCode: fiber.StatusOK,
},
{
name: "chained extractor with auth header",
config: Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) {
return false, errors.New("validation failed")
},
Extractor: extractors.Chain(extractors.FromQuery("q"), extractors.FromAuthHeader("MyScheme")),
},
expectedHeader: `MyScheme realm="Restricted"`,
expectedStatusCode: fiber.StatusUnauthorized,
},
{
name: "chained extractor without auth header",
config: Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) {
return false, errors.New("validation failed")
},
Extractor: extractors.Chain(extractors.FromQuery("q"), extractors.FromCookie("c")),
},
expectedHeader: `ApiKey realm="Restricted"`,
expectedStatusCode: fiber.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(tt.config))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("OK")
})
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
// Provide a key for the default extractor to find
if tt.config.Extractor.Extract == nil {
req.Header.Set(fiber.HeaderAuthorization, "Bearer somekey")
}
resp, err := app.Test(req)
require.NoError(t, err)
assert.Equal(t, tt.expectedStatusCode, resp.StatusCode)
assert.Equal(t, tt.expectedHeader, resp.Header.Get(fiber.HeaderWWWAuthenticate))
})
}
}
// Test_CustomChallenge verifies that Config.Challenge is sent verbatim when
// the extractor does not read the Authorization header.
func Test_CustomChallenge(t *testing.T) {
	const wantChallenge = `ApiKey realm="Restricted"`

	app := fiber.New()
	app.Use(New(Config{
		Extractor: extractors.FromQuery("api_key"),
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
		Challenge: wantChallenge,
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, wantChallenge, res.Header.Get("WWW-Authenticate"))
}
// Test_BearerErrorFields verifies that Error, ErrorDescription and ErrorURI
// are appended to the Bearer challenge, in that order, on a rejected request.
func Test_BearerErrorFields(t *testing.T) {
	const wantChallenge = `Bearer realm="Restricted", error="invalid_token", error_description="token expired", error_uri="https://example.com"`

	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
		Error:            "invalid_token",
		ErrorDescription: "token expired",
		ErrorURI:         "https://example.com",
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer something")

	res, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, wantChallenge, res.Header.Get("WWW-Authenticate"))
}
// Test_BearerErrorURIOnly verifies that error_description is omitted from the
// Bearer challenge when only Error and ErrorURI are configured.
func Test_BearerErrorURIOnly(t *testing.T) {
	const wantChallenge = `Bearer realm="Restricted", error="invalid_token", error_uri="https://example.com/docs"`

	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
		Error:    "invalid_token",
		ErrorURI: "https://example.com/docs",
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer something")

	res, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, wantChallenge, res.Header.Get("WWW-Authenticate"))
}
// Test_BearerInsufficientScope checks that the insufficient_scope error carries
// the configured scope in the Bearer challenge.
func Test_BearerInsufficientScope(t *testing.T) {
	reject := func(_ fiber.Ctx, _ string) (bool, error) {
		return false, errors.New("invalid")
	}

	app := fiber.New()
	app.Use(New(Config{
		Validator: reject,
		Error:     ErrorInsufficientScope,
		Scope:     "read",
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer something")

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, resp.StatusCode)

	got := resp.Header.Get("WWW-Authenticate")
	require.Equal(t, `Bearer realm="Restricted", error="insufficient_scope", scope="read"`, got)
}
func Test_ScopeValidation(t *testing.T) {
require.PanicsWithValue(t, "fiber: keyauth scope requires insufficient_scope error", func() {
New(Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
Scope: "foo",
})
})
require.PanicsWithValue(t, "fiber: keyauth insufficient_scope requires scope", func() {
New(Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
Error: ErrorInsufficientScope,
})
})
require.PanicsWithValue(t, "fiber: keyauth scope contains invalid token", func() {
New(Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
Error: ErrorInsufficientScope,
Scope: "read \"write\"",
})
})
require.NotPanics(t, func() {
New(Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
Error: ErrorInsufficientScope,
Scope: "read write:all",
})
})
}
// Test_WWWAuthenticateOnlyOn401 checks that no WWW-Authenticate header is sent
// when a custom ErrorHandler answers with a status other than 401.
func Test_WWWAuthenticateOnlyOn401(t *testing.T) {
	rejectAll := func(_ fiber.Ctx, _ string) (bool, error) {
		return false, errors.New("invalid")
	}
	forbid := func(c fiber.Ctx, _ error) error {
		return c.Status(fiber.StatusForbidden).SendString("forbidden")
	}

	app := fiber.New()
	app.Use(New(Config{Validator: rejectAll, ErrorHandler: forbid}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer bad")

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusForbidden, resp.StatusCode)
	require.Empty(t, resp.Header.Get("WWW-Authenticate"))
}
// Test_DefaultChallengeForNonAuthExtractor checks that a non-Authorization
// extractor (query parameter) yields the default ApiKey challenge.
func Test_DefaultChallengeForNonAuthExtractor(t *testing.T) {
	app := fiber.New()
	cfg := Config{
		Extractor: extractors.FromQuery("api_key"),
		Validator: func(_ fiber.Ctx, _ string) (bool, error) { return false, ErrMissingOrMalformedAPIKey },
	}
	app.Use(New(cfg))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	require.Equal(t, `ApiKey realm="Restricted"`, resp.Header.Get(fiber.HeaderWWWAuthenticate))
}
func Test_MultipleWWWAuthenticateChallenges(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Extractor: extractors.Chain(
extractors.FromAuthHeader("Bearer"),
extractors.FromAuthHeader("ApiKey"),
),
Validator: func(_ fiber.Ctx, _ string) (bool, error) { return false, errors.New("invalid") },
}))
app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })
res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
require.Equal(t, `Bearer realm="Restricted", ApiKey realm="Restricted"`, res.Header.Get(fiber.HeaderWWWAuthenticate))
}
// Test_ProxyAuthenticateHeader checks that a 407 response carries the challenge
// in Proxy-Authenticate and leaves WWW-Authenticate unset.
func Test_ProxyAuthenticateHeader(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) { return false, errors.New("invalid") },
		// Answer with 407 so the challenge is expected in Proxy-Authenticate.
		ErrorHandler: func(c fiber.Ctx, _ error) error {
			return c.Status(fiber.StatusProxyAuthRequired).SendString("proxy auth")
		},
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("OK") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer bad")

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusProxyAuthRequired, resp.StatusCode)
	require.Equal(t, `Bearer realm="Restricted"`, resp.Header.Get(fiber.HeaderProxyAuthenticate))
	require.Empty(t, resp.Header.Get(fiber.HeaderWWWAuthenticate))
}
func Test_New_InvalidErrorToken(t *testing.T) {
assert.Panics(t, func() {
New(Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
Error: "unsupported",
})
})
}
func Test_New_ErrorDescriptionRequiresError(t *testing.T) {
assert.Panics(t, func() {
New(Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
ErrorDescription: "desc",
})
})
}
func Test_New_ErrorURIRequiresError(t *testing.T) {
assert.PanicsWithValue(t, "fiber: keyauth error_uri requires error", func() {
New(Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
ErrorURI: "https://example.com/docs",
})
})
}
func Test_New_ErrorURIAbsolute(t *testing.T) {
assert.PanicsWithValue(t, "fiber: keyauth error_uri must be absolute", func() {
New(Config{
Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
Error: "invalid_token",
ErrorURI: "/docs",
})
})
}
================================================
FILE: middleware/limiter/config.go
================================================
package limiter
import (
"time"
"github.com/gofiber/fiber/v3"
)
// defaultLimiterMax is the request limit used when neither Max nor MaxFunc is configured.
const defaultLimiterMax = 5
// Config defines the config for middleware.
type Config struct {
// Storage is used to store the state of the middleware
//
// Default: an in memory store for this process only
Storage fiber.Storage
// LimiterMiddleware is the rate-limiting strategy that builds the handler
// (for example FixedWindow or SlidingWindow).
//
// Default: a new Fixed Window Rate Limiter
LimiterMiddleware Handler
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c fiber.Ctx) bool
// MaxFunc dynamically calculates, per request, the max requests supported by
// the rate limiter middleware. A result of 0 skips limiting for that request.
//
// Default: func(c fiber.Ctx) int {
// return cfg.Max
// }
MaxFunc func(c fiber.Ctx) int
// ExpirationFunc dynamically calculates, per request, the window length used
// for rate limiter entries.
//
// Default: A function that returns the static `Expiration` value from the config.
ExpirationFunc func(c fiber.Ctx) time.Duration
// KeyGenerator allows you to generate custom keys, by default c.IP() is used
//
// Default: func(c fiber.Ctx) string {
// return c.IP()
// }
KeyGenerator func(fiber.Ctx) string
// LimitReached is called when a request hits the limit
//
// Default: func(c fiber.Ctx) error {
// return c.SendStatus(fiber.StatusTooManyRequests)
// }
LimitReached fiber.Handler
// Max is the number of requests allowed per key within one `Expiration`
// window before LimitReached is called (a 429 response by default).
// Values <= 0 fall back to the default.
//
// Default: 5
Max int
// Expiration is the length of the rate-limit window. It is tracked with
// one-second resolution; values shorter than one second fall back to the
// default.
//
// Default: 1 * time.Minute
Expiration time.Duration
// When set to true, requests with StatusCode >= 400 won't be counted.
//
// Default: false
SkipFailedRequests bool
// When set to true, requests with StatusCode < 400 won't be counted.
//
// Default: false
SkipSuccessfulRequests bool
// When set to true, the middleware will not include the rate limit headers (X-RateLimit-* and Retry-After) in the response.
//
// Default: false
DisableHeaders bool
// DisableValueRedaction turns off masking limiter keys in logs and error messages when set to true.
//
// Default: false
DisableValueRedaction bool
}
// ConfigDefault is the default config. Storage is left nil, and every boolean
// toggle (SkipFailedRequests, SkipSuccessfulRequests, DisableHeaders,
// DisableValueRedaction) keeps its zero value, false.
var ConfigDefault = Config{
	LimiterMiddleware: FixedWindow{},
	Max:               defaultLimiterMax,
	Expiration:        1 * time.Minute,
	MaxFunc: func(_ fiber.Ctx) int {
		return defaultLimiterMax
	},
	// ExpirationFunc is deliberately nil: configDefault installs a closure
	// over the resolved Expiration so a caller-supplied value is honoured.
	KeyGenerator: func(c fiber.Ctx) string {
		return c.IP()
	},
	LimitReached: func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTooManyRequests)
	},
}
// configDefault resolves the middleware configuration. It starts from the first
// supplied Config, or from ConfigDefault when none is given, and fills every
// unset or invalid field from ConfigDefault. If MaxFunc or ExpirationFunc is
// unset, it becomes a closure returning the resolved Max or Expiration.
func configDefault(config ...Config) Config {
	cfg := ConfigDefault
	if len(config) > 0 {
		cfg = config[0]
	}

	// Callbacks and strategy.
	if cfg.Next == nil {
		cfg.Next = ConfigDefault.Next
	}
	if cfg.KeyGenerator == nil {
		cfg.KeyGenerator = ConfigDefault.KeyGenerator
	}
	if cfg.LimitReached == nil {
		cfg.LimitReached = ConfigDefault.LimitReached
	}
	if cfg.LimiterMiddleware == nil {
		cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware
	}

	// Static limits. Windows are counted in whole seconds, so an Expiration
	// below one second is treated as unset.
	if cfg.Max <= 0 {
		cfg.Max = ConfigDefault.Max
	}
	if int(cfg.Expiration.Seconds()) <= 0 {
		cfg.Expiration = ConfigDefault.Expiration
	}

	// Dynamic limits default to the static values resolved above.
	if cfg.MaxFunc == nil {
		cfg.MaxFunc = func(fiber.Ctx) int { return cfg.Max }
	}
	if cfg.ExpirationFunc == nil {
		cfg.ExpirationFunc = func(fiber.Ctx) time.Duration { return cfg.Expiration }
	}

	return cfg
}
================================================
FILE: middleware/limiter/limiter.go
================================================
package limiter
import (
"errors"
"github.com/gofiber/fiber/v3"
)
const (
// X-RateLimit-* headers set on responses that pass the limiter, unless
// Config.DisableHeaders is true: the effective limit, the hits left in the
// current window, and the seconds until the window resets.
xRateLimitLimit = "X-RateLimit-Limit"
xRateLimitRemaining = "X-RateLimit-Remaining"
xRateLimitReset = "X-RateLimit-Reset"
)
// Handler defines a rate-limiting strategy that can produce a middleware
// handler using the provided configuration. FixedWindow and SlidingWindow
// implement it; New passes the config after defaults have been applied.
type Handler interface {
// New builds the middleware handler for config. The implementations in this
// package fall back to configDefault() when config is nil.
New(config *Config) fiber.Handler
}
// New creates a new limiter middleware handler. It applies defaults to the
// optional config. The resulting Config goes to its LimiterMiddleware strategy
// (FixedWindow unless configured otherwise), which builds the handler.
func New(config ...Config) fiber.Handler {
	resolved := configDefault(config...)
	strategy := resolved.LimiterMiddleware
	return strategy.New(&resolved)
}
// getEffectiveStatusCode reports the status code a request finished with.
// If err wraps a *fiber.Error, that error's Code takes precedence; otherwise
// the response's current status code is returned.
func getEffectiveStatusCode(c fiber.Ctx, err error) int {
	var fiberErr *fiber.Error
	if err != nil && errors.As(err, &fiberErr) {
		return fiberErr.Code
	}
	return c.Response().StatusCode()
}
================================================
FILE: middleware/limiter/limiter_fixed.go
================================================
package limiter
import (
"fmt"
"strconv"
"sync"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
)
// FixedWindow implements a fixed-window rate limiting strategy: each key's hit
// counter is reset once its window (Config.ExpirationFunc) has elapsed. It is
// the default Config.LimiterMiddleware.
type FixedWindow struct{}
// New creates a new fixed window middleware handler.
//
// Every key produced by cfg.KeyGenerator owns a hit counter. The counter is
// reset once its window has elapsed; the window is cfg.ExpirationFunc,
// truncated to whole seconds. A request that pushes the counter past
// cfg.MaxFunc is answered by cfg.LimitReached. When SkipSuccessfulRequests or
// SkipFailedRequests matches the downstream status, the hit is rolled back
// after c.Next returns. A nil cfg falls back to configDefault().
func (FixedWindow) New(cfg *Config) fiber.Handler {
	if cfg == nil {
		defaultCfg := configDefault()
		cfg = &defaultCfg
	}

	// Serializes the read-modify-write cycles against the storage. Only
	// exclusive locking is ever needed, so a plain mutex suffices.
	var mux sync.Mutex

	// Create manager to simplify storage operations ( see manager.go )
	manager := newManager(cfg.Storage, !cfg.DisableValueRedaction)

	// Update timestamp every second
	utils.StartTimeStampUpdater()

	// Return new handler
	return func(c fiber.Ctx) error {
		// Generate maxRequests from generator, if no generator was provided the default value returned is 5
		maxRequests := cfg.MaxFunc(c)

		// Don't execute middleware if Next returns true or if the max is 0
		if (cfg.Next != nil && cfg.Next(c)) || maxRequests == 0 {
			return c.Next()
		}

		// Generate expiration from generator. Windows are tracked with
		// one-second timestamps. Anything shorter than a second would give a
		// zero-length window that resets on every request and silently
		// disables limiting. In that case, fall back to the default, as
		// configDefault does for the static Expiration.
		expirationDuration := cfg.ExpirationFunc(c)
		if expirationDuration.Seconds() < 1 {
			expirationDuration = ConfigDefault.Expiration
		}
		expiration := uint64(expirationDuration.Seconds())

		// Get key from request
		key := cfg.KeyGenerator(c)

		// Lock entry
		mux.Lock()

		reqCtx := c.Context()

		// Get entry from pool and release when finished
		e, err := manager.get(reqCtx, key)
		if err != nil {
			mux.Unlock()
			return err
		}

		// Get timestamp
		ts := uint64(utils.Timestamp())

		// Open a window for a new entry, or start a fresh one if the current
		// window has elapsed.
		if e.exp == 0 {
			e.exp = ts + expiration
		} else if ts >= e.exp {
			e.currHits = 0
			e.exp = ts + expiration
		}

		// Remember which window this hit is counted in, so a later rollback
		// only touches that window.
		windowExpiresAt := e.exp

		// Increment hits
		e.currHits++

		// Calculate when it resets in seconds
		resetInSec := e.exp - ts

		// Set how many hits we have left
		remaining := maxRequests - e.currHits

		// Update storage
		if setErr := manager.set(reqCtx, key, e, expirationDuration); setErr != nil {
			mux.Unlock()
			return fmt.Errorf("limiter: failed to persist state: %w", setErr)
		}

		// Unlock entry
		mux.Unlock()

		// Check if hits exceed the max
		if remaining < 0 {
			// Return response with Retry-After header
			// https://www.rfc-editor.org/rfc/rfc6585#section-4
			if !cfg.DisableHeaders {
				c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
			}

			// Call LimitReached handler
			return cfg.LimitReached(c)
		}

		// Continue stack for reaching c.Response().StatusCode()
		// Store err for returning
		err = c.Next()

		// Get the effective status code from either the error or response
		statusCode := getEffectiveStatusCode(c, err)

		// Check for SkipFailedRequests and SkipSuccessfulRequests
		if (cfg.SkipSuccessfulRequests && statusCode < fiber.StatusBadRequest) ||
			(cfg.SkipFailedRequests && statusCode >= fiber.StatusBadRequest) {
			// Lock entry
			mux.Lock()

			entry, getErr := manager.get(reqCtx, key)
			if getErr != nil {
				mux.Unlock()
				return getErr
			}

			// Only undo the hit if the stored entry is still the window it was
			// counted in and holds a positive count. The window may have been
			// reset meanwhile, by a later request or because the stored entry
			// is gone. Decrementing then would remove a hit that belongs to a
			// newer window, or drive the counter negative and grant extra
			// requests.
			if entry.exp == windowExpiresAt && entry.currHits > 0 {
				entry.currHits--
				remaining++

				if setErr := manager.set(reqCtx, key, entry, expirationDuration); setErr != nil {
					mux.Unlock()
					return fmt.Errorf("limiter: failed to persist state: %w", setErr)
				}
			}

			// Unlock entry
			mux.Unlock()
		}

		// We can continue, update RateLimit headers
		if !cfg.DisableHeaders {
			c.Set(xRateLimitLimit, strconv.Itoa(maxRequests))
			c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
			c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
		}

		return err
	}
}
================================================
FILE: middleware/limiter/limiter_sliding.go
================================================
package limiter
import (
"fmt"
"math"
"strconv"
"sync"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
)
// SlidingWindow implements the sliding-window rate limiting strategy. The
// effective rate combines the current window's hits with the previous
// window's hits, weighted by how much of the current window remains.
type SlidingWindow struct{}
// New creates a new sliding window middleware handler.
//
// For each key the handler keeps hit counts for the current and the previous
// window and estimates the rate as
//
//	ceil(prevHits * secondsLeftInCurrentWindow / windowLength) + currHits
//
// Requests whose estimate exceeds cfg.MaxFunc are answered by
// cfg.LimitReached. Skipped requests (SkipSuccessfulRequests /
// SkipFailedRequests) are removed from whichever bucket now holds them. A nil
// cfg falls back to configDefault().
func (SlidingWindow) New(cfg *Config) fiber.Handler {
	if cfg == nil {
		defaultCfg := configDefault()
		cfg = &defaultCfg
	}

	// Serializes the read-modify-write cycles against the storage. Only
	// exclusive locking is ever needed, so a plain mutex suffices.
	var mux sync.Mutex

	// Create manager to simplify storage operations ( see manager.go )
	manager := newManager(cfg.Storage, !cfg.DisableValueRedaction)

	// Update timestamp every second
	utils.StartTimeStampUpdater()

	// Return new handler
	return func(c fiber.Ctx) error {
		// Generate maxRequests from generator, if no generator was provided the default value returned is 5
		maxRequests := cfg.MaxFunc(c)

		// Don't execute middleware if Next returns true or if the max is 0
		if (cfg.Next != nil && cfg.Next(c)) || maxRequests == 0 {
			return c.Next()
		}

		// Generate expiration from generator. Windows are tracked in whole
		// seconds and `expiration` is the divisor of the weight below. A
		// sub-second duration would make it 0 and the weight NaN; the
		// conversion int(NaN) is implementation-defined. In that case, fall
		// back to the default, as configDefault does for the static Expiration.
		expirationDuration := cfg.ExpirationFunc(c)
		if expirationDuration.Seconds() < 1 {
			expirationDuration = ConfigDefault.Expiration
		}
		expiration := uint64(expirationDuration.Seconds())

		// Get key from request
		key := cfg.KeyGenerator(c)

		// Lock entry
		mux.Lock()

		reqCtx := c.Context()

		// Get entry from pool and release when finished
		e, err := manager.get(reqCtx, key)
		if err != nil {
			mux.Unlock()
			return err
		}

		// Get timestamp
		ts := uint64(utils.Timestamp())

		// Rotate window
		resetInSec := rotateWindow(e, ts, expiration)
		// Remember which window this hit is counted in, for a later rollback.
		windowExpiresAt := e.exp

		// Increment hits
		e.currHits++

		// weight = time until current window reset / total window length
		weight := float64(resetInSec) / float64(expiration)

		// rate = request count in previous window × weight + request count in current window
		rate := int(math.Ceil(float64(e.prevHits)*weight)) + e.currHits

		// Calculate how many hits can be made based on the current rate
		remaining := maxRequests - rate

		// Update storage. Garbage collect when the next window ends.
		// |--------------------------|--------------------------|
		//               ^            ^               ^          ^
		//              ts         e.exp  End sample window   End next window
		//               <------------>
		// 				   Reset In Sec
		// resetInSec = e.exp - ts - time until end of current window.
		// duration + expiration = end of next window.
		// Because we don't want to garbage collect in the middle of a window
		// we add the expiration to the duration.
		// Otherwise, after the end of "sample window", attackers could launch
		// a new request with the full window length.
		if setErr := manager.set(reqCtx, key, e, ttlDuration(resetInSec, expiration)); setErr != nil {
			mux.Unlock()
			return fmt.Errorf("limiter: failed to persist state: %w", setErr)
		}

		// Unlock entry
		mux.Unlock()

		// Check if hits exceed the allowed maximum for this request
		if remaining < 0 {
			// Return response with Retry-After header
			// https://www.rfc-editor.org/rfc/rfc6585#section-4
			if !cfg.DisableHeaders {
				c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
			}

			// Call LimitReached handler
			return cfg.LimitReached(c)
		}

		// Continue stack for reaching c.Response().StatusCode()
		// Store err for returning
		err = c.Next()

		// Get the effective status code from either the error or response
		statusCode := getEffectiveStatusCode(c, err)

		skipHit := (cfg.SkipSuccessfulRequests && statusCode < fiber.StatusBadRequest) ||
			(cfg.SkipFailedRequests && statusCode >= fiber.StatusBadRequest)

		// Re-read the entry when a hit must be rolled back or when fresh
		// header values are needed.
		if skipHit || !cfg.DisableHeaders {
			// Lock entry
			mux.Lock()

			entry, getErr := manager.get(reqCtx, key)
			if getErr != nil {
				mux.Unlock()
				return getErr
			}
			e = entry

			ts = uint64(utils.Timestamp())
			resetInSec = rotateWindow(e, ts, expiration)
			weight = float64(resetInSec) / float64(expiration)

			if skipHit {
				// Remove the hit from whichever bucket holds it now, never
				// letting a counter go negative.
				if counter := bucketForOriginalHit(e, windowExpiresAt, ts, expiration); counter != nil && *counter > 0 {
					*counter--
				}
			}

			rate = int(math.Ceil(float64(e.prevHits)*weight)) + e.currHits
			remaining = maxRequests - rate

			if setErr := manager.set(reqCtx, key, e, ttlDuration(resetInSec, expiration)); setErr != nil {
				mux.Unlock()
				return fmt.Errorf("limiter: failed to persist state: %w", setErr)
			}

			// Unlock entry
			mux.Unlock()

			// We can continue, update RateLimit headers
			if !cfg.DisableHeaders {
				c.Set(xRateLimitLimit, strconv.Itoa(maxRequests))
				c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
				c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
			}
		}

		return err
	}
}
// rotateWindow moves e's sliding window forward so that it contains ts, and
// returns the seconds left until that window ends.
//
//   - A fresh entry (exp == 0) opens a window ending at ts+expiration.
//   - If ts is past the window end by less than one window length, the current
//     count becomes the previous count. The window advances by exactly one
//     length, which keeps windows aligned.
//   - If ts is a full window or more past the end, both counts are cleared and
//     a new window ends at ts+expiration.
func rotateWindow(e *item, ts, expiration uint64) uint64 {
	switch {
	case e.exp == 0:
		e.exp = ts + expiration
	case ts < e.exp:
		// Still inside the current window; nothing to rotate.
	case ts-e.exp < expiration:
		e.prevHits, e.currHits = e.currHits, 0
		e.exp += expiration
	default:
		e.prevHits, e.currHits = 0, 0
		e.exp = ts + expiration
	}
	return e.exp - ts
}
// bucketForOriginalHit locates the counter in e that holds a hit recorded in
// the window ending at requestExpiration, assuming e has been rotated to ts.
// It returns the current counter if that window is still open, or the previous
// counter if the window closed less than one length ago. It returns nil if the
// hit has aged out entirely.
func bucketForOriginalHit(e *item, requestExpiration, ts, expiration uint64) *int {
	switch {
	case ts < requestExpiration:
		return &e.currHits
	case ts-requestExpiration < expiration:
		return &e.prevHits
	default:
		return nil
	}
}
// ttlDuration returns the storage TTL for a sliding-window entry: the seconds
// left in the current window plus one full window. Any result that cannot be
// represented saturates at math.MaxInt64.
func ttlDuration(resetInSec, expiration uint64) time.Duration {
	const ceiling = time.Duration(math.MaxInt64)

	reset, resetOK := secondsToDuration(resetInSec)
	window, windowOK := secondsToDuration(expiration)
	if !resetOK || !windowOK || reset > ceiling-window {
		return ceiling
	}
	return reset + window
}

// secondsToDuration converts whole seconds to a time.Duration. It returns
// (math.MaxInt64, false) when the result would overflow a Duration.
func secondsToDuration(seconds uint64) (time.Duration, bool) {
	limit := uint64(math.MaxInt64 / int64(time.Second))
	if seconds <= limit {
		return time.Duration(seconds) * time.Second, true
	}
	return time.Duration(math.MaxInt64), false
}
================================================
FILE: middleware/limiter/limiter_test.go
================================================
package limiter
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
)
// failingLimiterStorage is an in-memory storage double for tests. Individual
// operations can be forced to fail by registering an error under
// "get|<key>" or "set|<key>" in errs.
type failingLimiterStorage struct {
	data map[string][]byte
	errs map[string]error
}

// testLimiterClientKey is the fixed limiter key used by the storage-failure tests.
const testLimiterClientKey = "client-key"

// newFailingLimiterStorage returns a storage double with empty, ready-to-use
// maps and no injected errors.
func newFailingLimiterStorage() *failingLimiterStorage {
	s := new(failingLimiterStorage)
	s.data = map[string][]byte{}
	s.errs = map[string]error{}
	return s
}
// countingFailStorage wraps failingLimiterStorage: the first failAfterN calls
// to SetWithContext succeed, and every later call returns setFailErr.
type countingFailStorage struct {
	*failingLimiterStorage
	setFailErr error
	setCount   int
	failAfterN int
}

// newCountingFailStorage returns a countingFailStorage that starts failing
// sets with err after failAfterN successful ones.
func newCountingFailStorage(failAfterN int, err error) *countingFailStorage {
	s := &countingFailStorage{failAfterN: failAfterN, setFailErr: err}
	s.failingLimiterStorage = newFailingLimiterStorage()
	return s
}

// SetWithContext counts every call. It delegates to the wrapped storage while
// the count is within failAfterN, and returns setFailErr afterwards.
func (s *countingFailStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
	s.setCount++
	if s.setCount <= s.failAfterN {
		return s.failingLimiterStorage.SetWithContext(ctx, key, val, exp)
	}
	return s.setFailErr
}
// contextRecord captures what a storage call observed about its context: the
// key it was called with, the string stored under markerKey (if any), and
// whether the context had been canceled.
type contextRecord struct {
key string
value string
canceled bool
}
// contextRecorderLimiterStorage wraps failingLimiterStorage and appends a
// contextRecord for every GetWithContext (gets) and SetWithContext (sets) call.
type contextRecorderLimiterStorage struct {
*failingLimiterStorage
gets []contextRecord
sets []contextRecord
}
// sleepForRetryAfter blocks long enough for a rate-limited key to become
// usable again, based on resp's Retry-After header. Without the header it
// waits 500ms. Otherwise it waits twice the advertised delay, never less than
// 4s, plus 500ms of slack. The doubling is there because a sliding window also
// needs the previous window to age out.
func sleepForRetryAfter(t *testing.T, resp *http.Response) {
	t.Helper()

	header := resp.Header.Get(fiber.HeaderRetryAfter)
	if header == "" {
		time.Sleep(500 * time.Millisecond)
		return
	}

	secs, err := strconv.Atoi(header)
	require.NoError(t, err)

	wait := time.Duration(secs) * time.Second
	// Only double when it actually grows the delay (not for zero, negative,
	// or overflowing values).
	if twice := wait * 2; twice > wait {
		wait = twice
	}
	wait = max(wait, 4*time.Second)

	time.Sleep(wait + 500*time.Millisecond)
}
// newContextRecorderLimiterStorage returns a recorder backed by a fresh
// failingLimiterStorage, with no calls recorded yet.
func newContextRecorderLimiterStorage() *contextRecorderLimiterStorage {
	inner := newFailingLimiterStorage()
	return &contextRecorderLimiterStorage{failingLimiterStorage: inner}
}

// contextRecordFrom snapshots ctx for a call on key. It records the string
// stored under markerKey (empty when absent or not a string) and whether ctx
// has been canceled.
func contextRecordFrom(ctx context.Context, key string) contextRecord {
	marker, _ := ctx.Value(markerKey).(string)
	return contextRecord{
		key:      key,
		value:    marker,
		canceled: errors.Is(ctx.Err(), context.Canceled),
	}
}
// GetWithContext records the call's context, then delegates to the wrapped storage.
func (s *contextRecorderLimiterStorage) GetWithContext(ctx context.Context, key string) ([]byte, error) {
	rec := contextRecordFrom(ctx, key)
	s.gets = append(s.gets, rec)
	return s.failingLimiterStorage.GetWithContext(ctx, key)
}

// SetWithContext records the call's context, then delegates to the wrapped storage.
func (s *contextRecorderLimiterStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
	rec := contextRecordFrom(ctx, key)
	s.sets = append(s.sets, rec)
	return s.failingLimiterStorage.SetWithContext(ctx, key, val, exp)
}

// recordedGets returns a copy of the records captured by GetWithContext.
func (s *contextRecorderLimiterStorage) recordedGets() []contextRecord {
	snapshot := make([]contextRecord, 0, len(s.gets))
	return append(snapshot, s.gets...)
}

// recordedSets returns a copy of the records captured by SetWithContext.
func (s *contextRecorderLimiterStorage) recordedSets() []contextRecord {
	snapshot := make([]contextRecord, 0, len(s.sets))
	return append(snapshot, s.sets...)
}
// GetWithContext returns a copy of the bytes stored under key, or (nil, nil)
// when nothing is stored. If an error was registered under "get|"+key, that
// error is returned instead. The context is ignored.
func (s *failingLimiterStorage) GetWithContext(_ context.Context, key string) ([]byte, error) {
	if injected := s.errs["get|"+key]; injected != nil {
		return nil, injected
	}
	stored, found := s.data[key]
	if !found {
		return nil, nil
	}
	return append([]byte(nil), stored...), nil
}

// Get is GetWithContext with a background context.
func (s *failingLimiterStorage) Get(key string) ([]byte, error) {
	return s.GetWithContext(context.Background(), key)
}

// SetWithContext stores a copy of val under key, unless an error was
// registered under "set|"+key. The context and expiration are ignored.
func (s *failingLimiterStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error {
	if injected := s.errs["set|"+key]; injected != nil {
		return injected
	}
	s.data[key] = append([]byte(nil), val...)
	return nil
}

// Set is SetWithContext with a background context.
func (s *failingLimiterStorage) Set(key string, val []byte, exp time.Duration) error {
	return s.SetWithContext(context.Background(), key, val, exp)
}

// The remaining storage operations are no-ops that always succeed.

func (*failingLimiterStorage) DeleteWithContext(context.Context, string) error { return nil }

func (*failingLimiterStorage) Delete(string) error { return nil }

func (*failingLimiterStorage) ResetWithContext(context.Context) error { return nil }

func (*failingLimiterStorage) Reset() error { return nil }

func (*failingLimiterStorage) Close() error { return nil }
// contextKey is an unexported key type, so test context values cannot collide
// with keys from other packages.
type contextKey string

// markerKey is the context key under which tests store a label string.
const markerKey contextKey = "marker"

// contextWithMarker returns a background-derived context that carries label
// under markerKey.
func contextWithMarker(label string) context.Context {
	base := context.Background()
	return context.WithValue(base, markerKey, label)
}

// canceledContextWithMarker returns a context that carries label under
// markerKey and has already been canceled.
func canceledContextWithMarker(label string) context.Context {
	ctx, stop := context.WithCancel(contextWithMarker(label))
	stop()
	return ctx
}
// TestLimiterDefaultConfigNoPanic checks that New() with no config serves a
// request without panicking.
func TestLimiterDefaultConfigNoPanic(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	sendRequest := func() {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)
	}
	require.NotPanics(t, sendRequest)
}
// TestLimiterFixedStorageGetError checks that a storage read failure surfaces
// to the app's ErrorHandler with the key redacted.
func TestLimiterFixedStorageGetError(t *testing.T) {
	t.Parallel()

	storage := newFailingLimiterStorage()
	storage.errs["get|"+testLimiterClientKey] = errors.New("boom")

	var captured error
	handleErr := func(c fiber.Ctx, err error) error {
		captured = err
		return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
	}

	app := fiber.New(fiber.Config{ErrorHandler: handleErr})
	app.Use(New(Config{
		Storage:      storage,
		Max:          1,
		Expiration:   time.Second,
		KeyGenerator: func(fiber.Ctx) string { return testLimiterClientKey },
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	for _, fragment := range []string{"limiter: failed to get key", "[redacted]"} {
		require.ErrorContains(t, captured, fragment)
	}
}
// TestLimiterFixedStorageSetError checks that a storage write failure is
// wrapped as a persist error, mentions the store failure, and redacts the key.
func TestLimiterFixedStorageSetError(t *testing.T) {
	t.Parallel()

	storage := newFailingLimiterStorage()
	storage.errs["set|"+testLimiterClientKey] = errors.New("boom")

	var captured error
	handleErr := func(c fiber.Ctx, err error) error {
		captured = err
		return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
	}

	app := fiber.New(fiber.Config{ErrorHandler: handleErr})
	app.Use(New(Config{
		Storage:      storage,
		Max:          1,
		Expiration:   time.Second,
		KeyGenerator: func(fiber.Ctx) string { return testLimiterClientKey },
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	for _, fragment := range []string{
		"limiter: failed to persist state",
		"limiter: failed to store key",
		"[redacted]",
	} {
		require.ErrorContains(t, captured, fragment)
	}
}
// TestLimiterFixedPropagatesRequestContextToStorage checks that every storage
// call made by the fixed-window limiter receives the request's context. That
// includes the SkipSuccessfulRequests rollback, even when the context is
// already canceled.
func TestLimiterFixedPropagatesRequestContextToStorage(t *testing.T) {
	t.Parallel()

	storage := newContextRecorderLimiterStorage()

	app := fiber.New()
	// Install a marked context per route; /rollback's context is pre-canceled.
	app.Use(func(c fiber.Ctx) error {
		switch c.Path() {
		case "/normal":
			c.SetContext(contextWithMarker("fixed-normal"))
		case "/rollback":
			c.SetContext(canceledContextWithMarker("fixed-rollback"))
		}
		return c.Next()
	})
	app.Use(New(Config{
		Storage:                storage,
		Max:                    1,
		Expiration:             time.Minute,
		SkipSuccessfulRequests: true,
		KeyGenerator:           func(c fiber.Ctx) string { return c.Path() },
		LimiterMiddleware:      FixedWindow{},
	}))
	app.Get("/:mode", func(c fiber.Ctx) error { return c.SendString("ok") })

	for _, path := range []string{"/normal", "/rollback"} {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)
	}

	gets := storage.recordedGets()
	require.Len(t, gets, 4)
	sets := storage.recordedSets()
	require.Len(t, sets, 4)

	// Each key must appear exactly twice, always with that route's context.
	verifyRecords := func(t *testing.T, records []contextRecord, key, wantValue string, wantCanceled bool) {
		t.Helper()
		matched := make([]contextRecord, 0, len(records))
		for _, rec := range records {
			if rec.key == key {
				matched = append(matched, rec)
			}
		}
		require.Len(t, matched, 2)
		for _, rec := range matched {
			require.Equal(t, wantValue, rec.value)
			require.Equal(t, wantCanceled, rec.canceled)
		}
	}

	for _, tc := range []struct {
		records  []contextRecord
		key      string
		value    string
		canceled bool
	}{
		{gets, "/normal", "fixed-normal", false},
		{gets, "/rollback", "fixed-rollback", true},
		{sets, "/normal", "fixed-normal", false},
		{sets, "/rollback", "fixed-rollback", true},
	} {
		verifyRecords(t, tc.records, tc.key, tc.value, tc.canceled)
	}
}
// TestLimiterFixedStorageGetErrorDisableRedaction checks that, with redaction
// disabled, a storage read failure exposes the real key.
func TestLimiterFixedStorageGetErrorDisableRedaction(t *testing.T) {
	t.Parallel()

	storage := newFailingLimiterStorage()
	storage.errs["get|"+testLimiterClientKey] = errors.New("boom")

	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})
	app.Use(New(Config{
		DisableValueRedaction: true,
		Storage:               storage,
		Max:                   1,
		Expiration:            time.Second,
		KeyGenerator:          func(fiber.Ctx) string { return testLimiterClientKey },
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, testLimiterClientKey)
	require.NotContains(t, captured.Error(), "[redacted]")
}
// TestLimiterFixedStorageSetErrorDisableRedaction checks that, with redaction
// disabled, a storage write failure exposes the real key.
func TestLimiterFixedStorageSetErrorDisableRedaction(t *testing.T) {
	t.Parallel()

	storage := newFailingLimiterStorage()
	storage.errs["set|"+testLimiterClientKey] = errors.New("boom")

	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})
	app.Use(New(Config{
		DisableValueRedaction: true,
		Storage:               storage,
		Max:                   1,
		Expiration:            time.Second,
		KeyGenerator:          func(fiber.Ctx) string { return testLimiterClientKey },
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, testLimiterClientKey)
	require.NotContains(t, captured.Error(), "[redacted]")
}
// TestLimiterFixedStorageSetErrorOnSkipSuccessfulRequests checks that a failure
// while persisting the SkipSuccessfulRequests rollback is reported. The storage
// lets the first set through and fails the second one.
func TestLimiterFixedStorageSetErrorOnSkipSuccessfulRequests(t *testing.T) {
	t.Parallel()

	storage := newCountingFailStorage(1, errors.New("second set failed"))

	var captured error
	handleErr := func(c fiber.Ctx, err error) error {
		captured = err
		return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
	}

	app := fiber.New(fiber.Config{ErrorHandler: handleErr})
	app.Use(New(Config{
		Storage:                storage,
		Max:                    10,
		Expiration:             time.Second,
		SkipSuccessfulRequests: true,
		KeyGenerator:           func(fiber.Ctx) string { return testLimiterClientKey },
	}))
	app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) })

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, "limiter: failed to persist state")
}
// TestLimiterSlidingPropagatesRequestContextToStorage checks that every storage
// call made by the sliding-window limiter receives the request's context. That
// includes the post-handler update, even when the context is already canceled.
func TestLimiterSlidingPropagatesRequestContextToStorage(t *testing.T) {
	t.Parallel()

	storage := newContextRecorderLimiterStorage()

	app := fiber.New()
	// Install a marked context per route; /rollback's context is pre-canceled.
	app.Use(func(c fiber.Ctx) error {
		switch c.Path() {
		case "/normal":
			c.SetContext(contextWithMarker("sliding-normal"))
		case "/rollback":
			c.SetContext(canceledContextWithMarker("sliding-rollback"))
		}
		return c.Next()
	})
	app.Use(New(Config{
		Storage:                storage,
		Max:                    1,
		Expiration:             time.Minute,
		SkipSuccessfulRequests: true,
		KeyGenerator:           func(c fiber.Ctx) string { return c.Path() },
		LimiterMiddleware:      SlidingWindow{},
	}))
	app.Get("/:mode", func(c fiber.Ctx) error { return c.SendString("ok") })

	for _, path := range []string{"/normal", "/rollback"} {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)
	}

	gets := storage.recordedGets()
	require.Len(t, gets, 4)
	sets := storage.recordedSets()
	require.Len(t, sets, 4)

	// Each key must appear exactly twice, always with that route's context.
	verifyRecords := func(t *testing.T, records []contextRecord, key, wantValue string, wantCanceled bool) {
		t.Helper()
		matched := make([]contextRecord, 0, len(records))
		for _, rec := range records {
			if rec.key == key {
				matched = append(matched, rec)
			}
		}
		require.Len(t, matched, 2)
		for _, rec := range matched {
			require.Equal(t, wantValue, rec.value)
			require.Equal(t, wantCanceled, rec.canceled)
		}
	}

	for _, tc := range []struct {
		records  []contextRecord
		key      string
		value    string
		canceled bool
	}{
		{gets, "/normal", "sliding-normal", false},
		{gets, "/rollback", "sliding-rollback", true},
		{sets, "/normal", "sliding-normal", false},
		{sets, "/rollback", "sliding-rollback", true},
	} {
		verifyRecords(t, tc.records, tc.key, tc.value, tc.canceled)
	}
}
// TestLimiterSlidingSkipsPostUpdateWhenHeadersDisabled checks storage traffic
// for one successful request through the sliding-window limiter with
// DisableHeaders set. Exactly one storage Get and one storage Set are expected.
func TestLimiterSlidingSkipsPostUpdateWhenHeadersDisabled(t *testing.T) {
	t.Parallel()

	recorder := newContextRecorderLimiterStorage()
	app := fiber.New()
	app.Use(New(Config{
		Storage:           recorder,
		Max:               1,
		Expiration:        time.Second,
		DisableHeaders:    true,
		LimiterMiddleware: SlidingWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	// No extra post-handler storage round trip is expected in this configuration.
	require.Len(t, recorder.recordedGets(), 1)
	require.Len(t, recorder.recordedSets(), 1)
}
// go test -run '^Test_Limiter_With_Max_Func_With_Zero_And_Limiter_Sliding$' -race -v
//
// With a MaxFunc that returns 0 and the sliding-window algorithm, every request
// below is expected to reach the handler and return the handler's own status.
// That holds even after the 2s window would have expired.
func Test_Limiter_With_Max_Func_With_Zero_And_Limiter_Sliding(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		MaxFunc:                func(_ fiber.Ctx) int { return 0 },
		Expiration:             2 * time.Second,
		SkipFailedRequests:     false,
		SkipSuccessfulRequests: false,
		LimiterMiddleware:      SlidingWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	// Wait well past the 2s expiration and confirm requests still pass.
	time.Sleep(4*time.Second + 500*time.Millisecond)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}
// Test_Limiter_Sliding_MaxFuncOverridesStaticMax checks that MaxFunc's value
// (2) wins over the static Max (5) in the sliding-window limiter. Both the
// X-RateLimit-* headers and the admission decision must reflect it.
func Test_Limiter_Sliding_MaxFuncOverridesStaticMax(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	staticMax := 5
	dynamicMax := 2
	app.Use(New(Config{
		Max:               staticMax,
		MaxFunc:           func(fiber.Ctx) int { return dynamicMax },
		Expiration:        2 * time.Second,
		LimiterMiddleware: SlidingWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	// Spend the dynamic budget one request at a time, checking headers on each.
	for used := 1; used <= dynamicMax; used++ {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)
		require.Equal(t, strconv.Itoa(dynamicMax), resp.Header.Get("X-RateLimit-Limit"))
		require.Equal(t, strconv.Itoa(dynamicMax-used), resp.Header.Get("X-RateLimit-Remaining"))
	}

	// The next request is over the dynamic limit, even though the static Max would allow it.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode)
}
// go test -run '^Test_Limiter_With_Max_Func_With_Zero$' -race -v
//
// With a MaxFunc that returns 0, five concurrent requests and one follow-up
// request are all expected to be served with 200.
func Test_Limiter_With_Max_Func_With_Zero(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		MaxFunc: func(_ fiber.Ctx) int {
			return 0
		},
		Expiration: 2 * time.Second,
		Storage:    memory.New(),
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	var wg sync.WaitGroup
	for range 5 {
		wg.Go(func() {
			// Only assert.* is used here. require.* calls t.FailNow, which must
			// run on the test goroutine, not in a spawned one. On error, bail out
			// instead of dereferencing a nil response.
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, fiber.StatusOK, resp.StatusCode)
			body, err := io.ReadAll(resp.Body)
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, "Hello tester!", string(body))
		})
	}
	wg.Wait()

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}
// go test -run '^Test_Limiter_With_Max_Func$' -race -v
//
// Test_Limiter_With_Max_Func makes the following requests and checks each result:
//   - exactly MaxFunc() concurrent requests, all admitted;
//   - one extra request, which must be rejected with 429;
//   - after the 2s window has passed, a further request, which must be admitted.
func Test_Limiter_With_Max_Func(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	maxRequests := 10
	app.Use(New(Config{
		MaxFunc: func(_ fiber.Ctx) int {
			return maxRequests
		},
		Expiration: 2 * time.Second,
		Storage:    memory.New(),
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	var wg sync.WaitGroup
	for range maxRequests {
		wg.Go(func() {
			// Only assert.* is used here. require.* calls t.FailNow, which must
			// run on the test goroutine, not in a spawned one. On error, bail out
			// instead of dereferencing a nil response.
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, fiber.StatusOK, resp.StatusCode)
			body, err := io.ReadAll(resp.Body)
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, "Hello tester!", string(body))
		})
	}
	wg.Wait()

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	// Wait past the 2s window so the budget is available again.
	time.Sleep(3 * time.Second)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_ExpirationFuncOverridesStaticExpiration -race -v
//
// ExpirationFunc (2s) must take precedence over the static Expiration (10s).
// After the fixed-window budget is exhausted, a 3s pause is enough to be
// admitted again.
func Test_Limiter_Fixed_ExpirationFuncOverridesStaticExpiration(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:               2,
		Expiration:        10 * time.Second,
		ExpirationFunc:    func(_ fiber.Ctx) time.Duration { return 2 * time.Second },
		LimiterMiddleware: FixedWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	// get issues GET / and stops the test on transport errors.
	get := func() *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, fiber.StatusOK, get().StatusCode)
	require.Equal(t, fiber.StatusOK, get().StatusCode)
	require.Equal(t, fiber.StatusTooManyRequests, get().StatusCode)

	// 3s exceeds the 2s from ExpirationFunc but not the static 10s.
	time.Sleep(3 * time.Second)
	require.Equal(t, fiber.StatusOK, get().StatusCode)
}
// go test -run Test_Limiter_Sliding_ExpirationFuncOverridesStaticExpiration -race -v
//
// This is the sliding-window counterpart of the fixed-window ExpirationFunc test.
// ExpirationFunc (2s) must take precedence over the static Expiration (10s).
func Test_Limiter_Sliding_ExpirationFuncOverridesStaticExpiration(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:               2,
		Expiration:        10 * time.Second,
		ExpirationFunc:    func(_ fiber.Ctx) time.Duration { return 2 * time.Second },
		LimiterMiddleware: SlidingWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	// send issues GET / and stops the test on transport errors.
	send := func() *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, fiber.StatusOK, send().StatusCode)
	require.Equal(t, fiber.StatusOK, send().StatusCode)
	limited := send()
	require.Equal(t, fiber.StatusTooManyRequests, limited.StatusCode)

	// A sliding window needs about twice the expiration to fully reset, because
	// it also weighs the previous window. The wait is derived from the rejected
	// response.
	sleepForRetryAfter(t, limited)
	require.Equal(t, fiber.StatusOK, send().StatusCode)
}
// go test -run Test_Limiter_Fixed_ExpirationFunc_FallbackOnZeroDuration -race -v
//
// An ExpirationFunc that returns zero must not switch off limiting. The test name
// indicates that a fallback expiration applies. With Max: 1, the second request
// is still rejected.
func Test_Limiter_Fixed_ExpirationFunc_FallbackOnZeroDuration(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:               1,
		ExpirationFunc:    func(_ fiber.Ctx) time.Duration { return 0 },
		LimiterMiddleware: FixedWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	for _, want := range []int{fiber.StatusOK, fiber.StatusTooManyRequests} {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, want, resp.StatusCode)
	}
}
// go test -run Test_Limiter_Fixed_ExpirationFunc_FallbackOnNegativeDuration -race -v
//
// A negative ExpirationFunc result must be handled like zero: limiting stays in
// effect. With Max: 1, the second request is rejected.
func Test_Limiter_Fixed_ExpirationFunc_FallbackOnNegativeDuration(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:               1,
		ExpirationFunc:    func(_ fiber.Ctx) time.Duration { return -1 * time.Second },
		LimiterMiddleware: FixedWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	expected := [...]int{fiber.StatusOK, fiber.StatusTooManyRequests}
	for i := range expected {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, expected[i], resp.StatusCode)
	}
}
// go test -run Test_Limiter_Sliding_ExpirationFunc_FallbackOnZeroDuration -race -v
//
// This is the sliding-window variant: a zero ExpirationFunc result must not
// disable the limiter. With Max: 1, the second request is rejected.
func Test_Limiter_Sliding_ExpirationFunc_FallbackOnZeroDuration(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:               1,
		ExpirationFunc:    func(_ fiber.Ctx) time.Duration { return 0 },
		LimiterMiddleware: SlidingWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	for _, want := range []int{fiber.StatusOK, fiber.StatusTooManyRequests} {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, want, resp.StatusCode)
	}
}
// go test -run Test_Limiter_Sliding_ExpirationFunc_FallbackOnNegativeDuration -race -v
//
// This is the sliding-window variant: a negative ExpirationFunc result must not
// disable the limiter. With Max: 1, the second request is rejected.
func Test_Limiter_Sliding_ExpirationFunc_FallbackOnNegativeDuration(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:               1,
		ExpirationFunc:    func(_ fiber.Ctx) time.Duration { return -1 * time.Second },
		LimiterMiddleware: SlidingWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	expected := [...]int{fiber.StatusOK, fiber.StatusTooManyRequests}
	for i := range expected {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, expected[i], resp.StatusCode)
	}
}
// go test -run '^Test_Limiter_Concurrency_Store$' -race -v
//
// Test_Limiter_Concurrency_Store runs with an explicit memory storage:
//   - it fires Max (50) concurrent requests, all of which must be admitted;
//   - the next request must be rejected with 429;
//   - after the 2s window, a request must be admitted again.
func Test_Limiter_Concurrency_Store(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:        50,
		Expiration: 2 * time.Second,
		Storage:    memory.New(),
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	var wg sync.WaitGroup
	for range 50 {
		wg.Go(func() {
			// Only assert.* is used here. require.* calls t.FailNow, which must
			// run on the test goroutine, not in a spawned one. On error, bail out
			// instead of dereferencing a nil response.
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, fiber.StatusOK, resp.StatusCode)
			body, err := io.ReadAll(resp.Body)
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, "Hello tester!", string(body))
		})
	}
	wg.Wait()

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	// Wait past the 2s window so the budget is available again.
	time.Sleep(3 * time.Second)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}
// go test -run '^Test_Limiter_Concurrency$' -race -v
//
// Test_Limiter_Concurrency uses the default storage:
//   - it fires Max (50) concurrent requests, all of which must be admitted;
//   - the next request must be rejected with 429;
//   - after the 2s window, a request must be admitted again.
func Test_Limiter_Concurrency(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:        50,
		Expiration: 2 * time.Second,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	var wg sync.WaitGroup
	for range 50 {
		wg.Go(func() {
			// Only assert.* is used here. require.* calls t.FailNow, which must
			// run on the test goroutine, not in a spawned one. On error, bail out
			// instead of dereferencing a nil response.
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, fiber.StatusOK, resp.StatusCode)
			body, err := io.ReadAll(resp.Body)
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, "Hello tester!", string(body))
		})
	}
	wg.Wait()

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	// Wait past the 2s window so the budget is available again.
	time.Sleep(3 * time.Second)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_No_Skip_Choices -v
//
// With neither skip option enabled, failed and successful responses both count
// toward the fixed-window budget of two requests.
func Test_Limiter_Fixed_Window_No_Skip_Choices(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    2,
		Expiration:             2 * time.Second,
		SkipFailedRequests:     false,
		SkipSuccessfulRequests: false,
		LimiterMiddleware:      FixedWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") != "fail" {
			return c.SendStatus(200)
		}
		return c.SendStatus(400)
	})

	// get issues a GET to path and stops the test on transport errors.
	get := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 400, get("/fail").StatusCode)
	require.Equal(t, 200, get("/success").StatusCode)
	require.Equal(t, 429, get("/success").StatusCode)

	// Let the 2s window lapse so the counter resets.
	time.Sleep(3 * time.Second)
	require.Equal(t, 200, get("/success").StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices -v
//
// This is the same scenario as Test_Limiter_Fixed_Window_No_Skip_Choices, but
// it runs against an explicit memory storage.
func Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    2,
		Expiration:             2 * time.Second,
		SkipFailedRequests:     false,
		SkipSuccessfulRequests: false,
		Storage:                memory.New(),
		LimiterMiddleware:      FixedWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		status := 200
		if c.Params("status") == "fail" {
			status = 400
		}
		return c.SendStatus(status)
	})

	// hit issues a GET to path and stops the test on transport errors.
	hit := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 400, hit("/fail").StatusCode)
	require.Equal(t, 200, hit("/success").StatusCode)
	require.Equal(t, 429, hit("/success").StatusCode)

	// Let the 2s window lapse so the counter resets.
	time.Sleep(3 * time.Second)
	require.Equal(t, 200, hit("/success").StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_No_Skip_Choices -v
//
// With neither skip option enabled, failed and successful responses both count
// toward the sliding-window budget of two requests.
func Test_Limiter_Sliding_Window_No_Skip_Choices(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    2,
		Expiration:             2 * time.Second,
		SkipFailedRequests:     false,
		SkipSuccessfulRequests: false,
		LimiterMiddleware:      SlidingWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		switch c.Params("status") {
		case "fail":
			return c.SendStatus(400)
		default:
			return c.SendStatus(200)
		}
	})

	// fetch issues a GET to path and stops the test on transport errors.
	fetch := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 400, fetch("/fail").StatusCode)
	require.Equal(t, 200, fetch("/success").StatusCode)
	limited := fetch("/success")
	require.Equal(t, 429, limited.StatusCode)

	// The wait is derived from the rejected response before retrying.
	sleepForRetryAfter(t, limited)
	require.Equal(t, 200, fetch("/success").StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices -v
//
// This is the same scenario as Test_Limiter_Sliding_Window_No_Skip_Choices, but
// it runs against an explicit memory storage.
func Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    2,
		Expiration:             2 * time.Second,
		SkipFailedRequests:     false,
		SkipSuccessfulRequests: false,
		Storage:                memory.New(),
		LimiterMiddleware:      SlidingWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") != "fail" {
			return c.SendStatus(200)
		}
		return c.SendStatus(400)
	})

	// do issues a GET to path and stops the test on transport errors.
	do := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 400, do("/fail").StatusCode)
	require.Equal(t, 200, do("/success").StatusCode)
	rejected := do("/success")
	require.Equal(t, 429, rejected.StatusCode)

	// The wait is derived from the rejected response before retrying.
	sleepForRetryAfter(t, rejected)
	require.Equal(t, 200, do("/success").StatusCode)
}
// Test_Limiter_Sliding_Window_RecalculatesAfterHandlerDelay fills the
// sliding-window budget of two using slow (600ms) handlers. It then waits past
// the 1s window and expects the next request to be admitted with Limit 2 and
// Remaining 1.
func Test_Limiter_Sliding_Window_RecalculatesAfterHandlerDelay(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:               2,
		Expiration:        time.Second,
		LimiterMiddleware: SlidingWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		// The handler is deliberately slow. Per the test name, this targets
		// limiter state that is recalculated after the handler returns.
		time.Sleep(600 * time.Millisecond)
		return c.SendStatus(fiber.StatusOK)
	})

	for range 2 {
		first, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, first.StatusCode)
	}

	time.Sleep(time.Second + 100*time.Millisecond)

	final, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, final.StatusCode)

	hdr := final.Header
	require.Equal(t, "2", hdr.Get(xRateLimitLimit))
	require.Equal(t, "1", hdr.Get(xRateLimitRemaining))
	require.NotEmpty(t, hdr.Get(xRateLimitReset))
}
// Test_Limiter_Sliding_Window_ExpiresStalePrevHits uses a 1s window with Max: 1.
// A hit is recorded, then the test pauses 2.5s (more than two windows). The next
// request must be admitted, and X-RateLimit-Remaining must report 0, meaning the
// stale hit no longer counts against it.
func Test_Limiter_Sliding_Window_ExpiresStalePrevHits(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:               1,
		Expiration:        time.Second,
		LimiterMiddleware: SlidingWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	// send issues GET / and stops the test on transport errors.
	send := func() *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, fiber.StatusOK, send().StatusCode)

	time.Sleep(2500 * time.Millisecond)

	later := send()
	require.Equal(t, fiber.StatusOK, later.StatusCode)
	require.Equal(t, "0", later.Header.Get(xRateLimitRemaining))
}
// Test_Limiter_Sliding_Window_SkipFailedRequests_DecrementsPreviousWindow
// exercises SkipFailedRequests when a failed request finishes only after the
// window in which it started has rolled over.
//
// Timeline (Expiration 200ms, Max 2):
//   - t≈0ms:   /fail starts; its handler sleeps 300ms and then returns 500.
//   - t≈220ms: /ok is sent and must succeed.
//   - t≈300ms: /fail completes.
//
// Per the test name, the skipped failure is expected to be removed from the
// window it was counted in, which by then is the previous window. The failed
// response must report Limit 2 and Remaining 1.
func Test_Limiter_Sliding_Window_SkipFailedRequests_DecrementsPreviousWindow(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Max:                2,
		Expiration:         200 * time.Millisecond,
		SkipFailedRequests: true,
		LimiterMiddleware:  SlidingWindow{},
	}))
	app.Get("/:mode", func(c fiber.Ctx) error {
		if c.Params("mode") == "fail" {
			// Keep the failing request in flight past the 200ms window boundary.
			time.Sleep(300 * time.Millisecond)
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})
	// respErr carries the failing request's result back to the test goroutine.
	type respErr struct {
		resp *http.Response
		err  error
	}
	// Buffered so the sender never blocks, even if the test aborts early.
	failCh := make(chan respErr, 1)
	// Run /fail concurrently. The goroutine only sends on the channel and makes
	// no require.* calls; all assertions happen on the test goroutine below.
	go func() {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
		failCh <- respErr{resp: resp, err: err}
	}()
	// Cross into the next 200ms window while /fail is still sleeping.
	time.Sleep(220 * time.Millisecond)
	successResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/ok", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, successResp.StatusCode)
	result := <-failCh
	require.NoError(t, result.err)
	require.Equal(t, fiber.StatusInternalServerError, result.resp.StatusCode)
	require.Equal(t, "2", result.resp.Header.Get(xRateLimitLimit))
	require.Equal(t, "1", result.resp.Header.Get(xRateLimitRemaining))
	assert.NotEmpty(t, result.resp.Header.Get(xRateLimitReset))
}
// go test -run Test_Limiter_Fixed_Window_Skip_Failed_Requests -v
//
// With SkipFailedRequests, the 400 response is not counted. The single
// fixed-window slot (Max: 1) is therefore consumed by the first successful
// request.
func Test_Limiter_Fixed_Window_Skip_Failed_Requests(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                1,
		Expiration:         2 * time.Second,
		SkipFailedRequests: true,
		LimiterMiddleware:  FixedWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		status := 200
		if c.Params("status") == "fail" {
			status = 400
		}
		return c.SendStatus(status)
	})

	// fetch issues a GET to path and stops the test on transport errors.
	fetch := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 400, fetch("/fail").StatusCode)
	require.Equal(t, 200, fetch("/success").StatusCode)
	require.Equal(t, 429, fetch("/success").StatusCode)

	// Let the 2s window lapse so the counter resets.
	time.Sleep(3 * time.Second)
	require.Equal(t, 200, fetch("/success").StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests -v
//
// This is the same scenario as Test_Limiter_Fixed_Window_Skip_Failed_Requests,
// but it runs against an explicit memory storage.
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                1,
		Expiration:         2 * time.Second,
		Storage:            memory.New(),
		SkipFailedRequests: true,
		LimiterMiddleware:  FixedWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		switch c.Params("status") {
		case "fail":
			return c.SendStatus(400)
		default:
			return c.SendStatus(200)
		}
	})

	// request issues a GET to path and stops the test on transport errors.
	request := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 400, request("/fail").StatusCode)
	require.Equal(t, 200, request("/success").StatusCode)
	require.Equal(t, 429, request("/success").StatusCode)

	// Let the 2s window lapse so the counter resets.
	time.Sleep(3 * time.Second)
	require.Equal(t, 200, request("/success").StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Skip_Failed_Requests -v
//
// This is the sliding-window variant: the 400 response is skipped, so only the
// successful requests consume the single slot (Max: 1).
func Test_Limiter_Sliding_Window_Skip_Failed_Requests(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                1,
		Expiration:         2 * time.Second,
		SkipFailedRequests: true,
		LimiterMiddleware:  SlidingWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") != "fail" {
			return c.SendStatus(200)
		}
		return c.SendStatus(400)
	})

	// get issues a GET to path and stops the test on transport errors.
	get := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 400, get("/fail").StatusCode)
	require.Equal(t, 200, get("/success").StatusCode)
	limited := get("/success")
	require.Equal(t, 429, limited.StatusCode)

	// The wait is derived from the rejected response before retrying.
	sleepForRetryAfter(t, limited)
	require.Equal(t, 200, get("/success").StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests -v
//
// This is the same scenario as Test_Limiter_Sliding_Window_Skip_Failed_Requests,
// but it runs against an explicit memory storage.
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                1,
		Expiration:         2 * time.Second,
		Storage:            memory.New(),
		SkipFailedRequests: true,
		LimiterMiddleware:  SlidingWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		status := 200
		if c.Params("status") == "fail" {
			status = 400
		}
		return c.SendStatus(status)
	})

	// send issues a GET to path and stops the test on transport errors.
	send := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 400, send("/fail").StatusCode)
	require.Equal(t, 200, send("/success").StatusCode)
	rejected := send("/success")
	require.Equal(t, 429, rejected.StatusCode)

	// The wait is derived from the rejected response before retrying.
	sleepForRetryAfter(t, rejected)
	require.Equal(t, 200, send("/success").StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Skip_Successful_Requests -v
//
// With SkipSuccessfulRequests, the 200 response is not counted. The single
// fixed-window slot (Max: 1) is therefore consumed by the first failing request.
func Test_Limiter_Fixed_Window_Skip_Successful_Requests(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    1,
		Expiration:             2 * time.Second,
		SkipSuccessfulRequests: true,
		LimiterMiddleware:      FixedWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		switch c.Params("status") {
		case "fail":
			return c.SendStatus(400)
		default:
			return c.SendStatus(200)
		}
	})

	// hit issues a GET to path and stops the test on transport errors.
	hit := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 200, hit("/success").StatusCode)
	require.Equal(t, 400, hit("/fail").StatusCode)
	require.Equal(t, 429, hit("/fail").StatusCode)

	// Let the 2s window lapse so the counter resets.
	time.Sleep(3 * time.Second)
	require.Equal(t, 400, hit("/fail").StatusCode)
}
// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests -v
//
// This is the same scenario as Test_Limiter_Fixed_Window_Skip_Successful_Requests,
// but it runs against an explicit memory storage.
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    1,
		Expiration:             2 * time.Second,
		Storage:                memory.New(),
		SkipSuccessfulRequests: true,
		LimiterMiddleware:      FixedWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") != "fail" {
			return c.SendStatus(200)
		}
		return c.SendStatus(400)
	})

	// do issues a GET to path and stops the test on transport errors.
	do := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 200, do("/success").StatusCode)
	require.Equal(t, 400, do("/fail").StatusCode)
	require.Equal(t, 429, do("/fail").StatusCode)

	// Let the 2s window lapse so the counter resets.
	time.Sleep(3 * time.Second)
	require.Equal(t, 400, do("/fail").StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Skip_Successful_Requests -v
//
// This is the sliding-window variant: the 200 response is skipped, so only the
// failing requests consume the single slot (Max: 1).
func Test_Limiter_Sliding_Window_Skip_Successful_Requests(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    1,
		Expiration:             2 * time.Second,
		SkipSuccessfulRequests: true,
		LimiterMiddleware:      SlidingWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		status := 200
		if c.Params("status") == "fail" {
			status = 400
		}
		return c.SendStatus(status)
	})

	// fetch issues a GET to path and stops the test on transport errors.
	fetch := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 200, fetch("/success").StatusCode)
	require.Equal(t, 400, fetch("/fail").StatusCode)
	limited := fetch("/fail")
	require.Equal(t, 429, limited.StatusCode)

	// The wait is derived from the rejected response before retrying.
	sleepForRetryAfter(t, limited)
	require.Equal(t, 400, fetch("/fail").StatusCode)
}
// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests -v
//
// This is the same scenario as Test_Limiter_Sliding_Window_Skip_Successful_Requests,
// but it runs against an explicit memory storage.
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    1,
		Expiration:             2 * time.Second,
		Storage:                memory.New(),
		SkipSuccessfulRequests: true,
		LimiterMiddleware:      SlidingWindow{},
	}))
	app.Get("/:status", func(c fiber.Ctx) error {
		switch c.Params("status") {
		case "fail":
			return c.SendStatus(400)
		default:
			return c.SendStatus(200)
		}
	})

	// request issues a GET to path and stops the test on transport errors.
	request := func(path string) *http.Response {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		return resp
	}

	require.Equal(t, 200, request("/success").StatusCode)
	require.Equal(t, 400, request("/fail").StatusCode)
	rejected := request("/fail")
	require.Equal(t, 429, rejected.StatusCode)

	// The wait is derived from the rejected response before retrying.
	sleepForRetryAfter(t, rejected)
	require.Equal(t, 400, request("/fail").StatusCode)
}
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
//
// Benchmark_Limiter_Custom_Store drives the limiter, backed by an explicit
// memory storage, with the same request context on every iteration.
//
// NOTE(review): every iteration presumably maps to the same limiter key. If so,
// once Max (100) is spent inside the 60s window, the steady state measures the
// rejection path.
func Benchmark_Limiter_Custom_Store(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{
		Max:        100,
		Expiration: 60 * time.Second,
		Storage:    memory.New(),
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	handler := app.Handler()

	reqCtx := new(fasthttp.RequestCtx)
	reqCtx.Request.Header.SetMethod(fiber.MethodGet)
	reqCtx.Request.SetRequestURI("/")

	for b.Loop() {
		handler(reqCtx)
	}
}
// Regression test (SlidingWindow): a handler error created with fiber.NewErrorf
// must be treated as a failed request. With SkipSuccessfulRequests enabled and
// SkipFailedRequests disabled, it must still count toward the limit.
func Test_Limiter_Bug_NewErrorf_SkipSuccessfulRequests_SlidingWindow(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    1,
		Expiration:             60 * time.Second,
		LimiterMiddleware:      SlidingWindow{},
		SkipSuccessfulRequests: true,
		SkipFailedRequests:     false,
		DisableHeaders:         true,
	}))
	app.Get("/", func(_ fiber.Ctx) error {
		return fiber.NewErrorf(fiber.StatusInternalServerError, "Error")
	})

	// The first request is admitted, and the handler's error yields a 500. As a
	// failed request, it consumes the single slot (Max: 1).
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)

	// The counted failure exhausted the budget, so the second request must be
	// rate limited rather than reaching the handler's 500 again.
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode, "Second request should be rate limited")
}
// Regression test (FixedWindow): a handler error created with fiber.NewErrorf
// must be treated as a failed request. With SkipSuccessfulRequests enabled and
// SkipFailedRequests disabled, it must still count toward the limit.
func Test_Limiter_Bug_NewErrorf_SkipSuccessfulRequests_FixedWindow(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:                    1,
		Expiration:             60 * time.Second,
		LimiterMiddleware:      FixedWindow{},
		SkipSuccessfulRequests: true,
		SkipFailedRequests:     false,
		DisableHeaders:         true,
	}))
	app.Get("/", func(_ fiber.Ctx) error {
		return fiber.NewErrorf(fiber.StatusInternalServerError, "Error")
	})

	// The first request is admitted, and the handler's error yields a 500. As a
	// failed request, it consumes the single slot (Max: 1).
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)

	// The counted failure exhausted the budget, so the second request must be
	// rate limited rather than reaching the handler's 500 again.
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode, "Second request should be rate limited")
}
// go test -run Test_Limiter_Next
//
// Next always returns true, so the limiter is expected to step aside. No routes
// are registered, so the request should end in a 404.
func Test_Limiter_Next(t *testing.T) {
	t.Parallel()

	skipAlways := func(_ fiber.Ctx) bool {
		return true
	}

	app := fiber.New()
	app.Use(New(Config{Next: skipAlways}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// Test_Limiter_Headers serves one request through the raw fasthttp handler and
// checks the rate-limit headers:
//   - X-RateLimit-Limit must equal Max (50);
//   - X-RateLimit-Remaining must be present;
//   - X-RateLimit-Reset must be "1" or "2" for the 2s window.
//
// The last two checks are non-fatal.
func Test_Limiter_Headers(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:        50,
		Expiration: 2 * time.Second,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod(fiber.MethodGet)
	reqCtx.Request.SetRequestURI("/")
	app.Handler()(reqCtx)

	hdr := &reqCtx.Response.Header
	require.Equal(t, "50", string(hdr.Peek("X-RateLimit-Limit")))

	if len(hdr.Peek("X-RateLimit-Remaining")) == 0 {
		t.Error("The X-RateLimit-Remaining header is not set correctly - value is an empty string.")
	}

	switch string(hdr.Peek("X-RateLimit-Reset")) {
	case "1", "2":
		// within the expected bounds for a 2s window
	default:
		t.Error("The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
	}
}
// Test_Limiter_Disable_Headers checks responses when DisableHeaders is set:
//   - neither an admitted response nor a rejected (429) response carries any
//     X-RateLimit-* header;
//   - the rejected response also carries no Retry-After header.
func Test_Limiter_Disable_Headers(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		Max:            1,
		Expiration:     2 * time.Second,
		DisableHeaders: true,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	// serve runs a fresh GET / through the app handler and returns the context
	// so the response can be inspected.
	serve := func() *fasthttp.RequestCtx {
		fc := &fasthttp.RequestCtx{}
		fc.Request.Header.SetMethod(fiber.MethodGet)
		fc.Request.SetRequestURI("/")
		app.Handler()(fc)
		return fc
	}
	rateLimitHeaders := []string{"X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"}

	// The first request is within the limit and must carry no rate-limit headers.
	first := serve()
	require.Equal(t, fiber.StatusOK, first.Response.StatusCode())
	require.Equal(t, "Hello tester!", string(first.Response.Body()))
	for _, name := range rateLimitHeaders {
		require.Empty(t, string(first.Response.Header.Peek(name)))
	}

	// The second request exceeds Max: 1 and gets a 429. It must still have no
	// headers, including Retry-After.
	second := serve()
	require.Equal(t, fiber.StatusTooManyRequests, second.Response.StatusCode())
	require.Empty(t, string(second.Response.Header.Peek(fiber.HeaderRetryAfter)))
	for _, name := range rateLimitHeaders {
		require.Empty(t, string(second.Response.Header.Peek(name)))
	}
}
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
func Benchmark_Limiter(b *testing.B) {
app := fiber.New()
app.Use(New(Config{
Max: 100,
Expiration: 60 * time.Second,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello, World!")
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI("/")
for b.Loop() {
h(fctx)
}
}
// go test -run Test_Sliding_Window -race -v
func Test_Sliding_Window(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Max: 10,
Expiration: 1 * time.Second,
Storage: memory.New(),
LimiterMiddleware: SlidingWindow{},
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello tester!")
})
singleRequest := func(shouldFail bool) {
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
if shouldFail {
require.NoError(t, err)
require.Equal(t, 429, resp.StatusCode)
} else {
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
}
for range 5 {
singleRequest(false)
}
time.Sleep(3 * time.Second)
for range 5 {
singleRequest(false)
}
time.Sleep(3 * time.Second)
for range 5 {
singleRequest(false)
}
time.Sleep(3 * time.Second)
for range 10 {
singleRequest(false)
}
// requests should fail now
for range 5 {
singleRequest(true)
}
}
================================================
FILE: middleware/limiter/manager.go
================================================
package limiter
import (
"context"
"fmt"
"sync"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/memory"
)
// msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported
//
//go:generate msgp -o=manager_msgp.go -tests=false -unexported
type item struct {
currHits int
prevHits int
exp uint64
}
//msgp:ignore manager
type manager struct {
pool sync.Pool
memory *memory.Storage
storage fiber.Storage
redactKeys bool
}
const redactedKey = "[redacted]"
func newManager(storage fiber.Storage, redactKeys bool) *manager {
// Create new storage handler
manager := &manager{
pool: sync.Pool{
New: func() any {
return new(item)
},
},
redactKeys: redactKeys,
}
if storage != nil {
// Use provided storage if provided
manager.storage = storage
} else {
// Fallback too memory storage
manager.memory = memory.New()
}
return manager
}
// acquire returns an *entry from the sync.Pool
func (m *manager) acquire() *item {
return m.pool.Get().(*item) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
}
// release and reset *entry to sync.Pool
func (m *manager) release(e *item) {
e.prevHits = 0
e.currHits = 0
e.exp = 0
m.pool.Put(e)
}
// get data from storage or memory
func (m *manager) get(ctx context.Context, key string) (*item, error) {
if m.storage != nil {
raw, err := m.storage.GetWithContext(ctx, key)
if err != nil {
return nil, fmt.Errorf("limiter: failed to get key %q from storage: %w", m.logKey(key), err)
}
if raw != nil {
it := m.acquire()
if _, err := it.UnmarshalMsg(raw); err != nil {
m.release(it)
return nil, fmt.Errorf("limiter: failed to unmarshal key %q: %w", m.logKey(key), err)
}
return it, nil
}
return m.acquire(), nil
}
value := m.memory.Get(key)
if value == nil {
return m.acquire(), nil
}
it, ok := value.(*item)
if !ok {
return nil, fmt.Errorf("limiter: unexpected entry type %T for key %q", value, m.logKey(key))
}
return it, nil
}
// set data to storage or memory
func (m *manager) set(ctx context.Context, key string, it *item, exp time.Duration) error {
if m.storage != nil {
raw, err := it.MarshalMsg(nil)
if err != nil {
m.release(it)
return fmt.Errorf("limiter: failed to marshal key %q: %w", m.logKey(key), err)
}
if err := m.storage.SetWithContext(ctx, key, raw, exp); err != nil {
m.release(it)
return fmt.Errorf("limiter: failed to store key %q: %w", m.logKey(key), err)
}
m.release(it)
return nil
}
m.memory.Set(key, it, exp)
return nil
}
func (m *manager) logKey(key string) string {
if m.redactKeys {
return redactedKey
}
return key
}
================================================
FILE: middleware/limiter/manager_msgp.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package limiter
import (
"github.com/tinylib/msgp/msgp"
)
// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "currHits":
z.currHits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
case "prevHits":
z.prevHits, err = dc.ReadInt()
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
case "exp":
z.exp, err = dc.ReadUint64()
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 3
// write "currHits"
err = en.Append(0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.currHits)
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
// write "prevHits"
err = en.Append(0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
if err != nil {
return
}
err = en.WriteInt(z.prevHits)
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
// write "exp"
err = en.Append(0xa3, 0x65, 0x78, 0x70)
if err != nil {
return
}
err = en.WriteUint64(z.exp)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 3
// string "currHits"
o = append(o, 0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.currHits)
// string "prevHits"
o = append(o, 0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
o = msgp.AppendInt(o, z.prevHits)
// string "exp"
o = append(o, 0xa3, 0x65, 0x78, 0x70)
o = msgp.AppendUint64(o, z.exp)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "currHits":
z.currHits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "currHits")
return
}
case "prevHits":
z.prevHits, bts, err = msgp.ReadIntBytes(bts)
if err != nil {
err = msgp.WrapError(err, "prevHits")
return
}
case "exp":
z.exp, bts, err = msgp.ReadUint64Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "exp")
return
}
default:
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z item) Msgsize() (s int) {
s = 1 + 9 + msgp.IntSize + 9 + msgp.IntSize + 4 + msgp.Uint64Size
return
}
================================================
FILE: middleware/limiter/manager_msgp_test.go
================================================
package limiter
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
"bytes"
"testing"
"github.com/tinylib/msgp/msgp"
)
func TestMarshalUnmarshalitem(t *testing.T) {
v := item{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
func BenchmarkMarshalMsgitem(b *testing.B) {
v := item{}
b.ReportAllocs()
for b.Loop() {
v.MarshalMsg(nil)
}
}
func BenchmarkAppendMsgitem(b *testing.B) {
v := item{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
for b.Loop() {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
func BenchmarkUnmarshalitem(b *testing.B) {
v := item{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
for b.Loop() {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
func TestEncodeDecodeitem(t *testing.T) {
v := item{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecodeitem Msgsize() is inaccurate")
}
vn := item{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
func BenchmarkEncodeitem(b *testing.B) {
v := item{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
for b.Loop() {
v.EncodeMsg(en)
}
en.Flush()
}
func BenchmarkDecodeitem(b *testing.B) {
v := item{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
for b.Loop() {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}
================================================
FILE: middleware/logger/config.go
================================================
package logger
import (
"io"
"os"
"time"
"github.com/gofiber/fiber/v3"
)
// Config defines the config for middleware.
type Config struct {
// Stream is a writer where logs are written
//
// Default: os.Stdout
Stream io.Writer
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c fiber.Ctx) bool
// Skip is a function to determine if logging is skipped or written to Stream.
//
// Optional. Default: nil
Skip func(c fiber.Ctx) bool
// Done is a function that is called after the log string for a request is written to Output,
// and pass the log string as parameter.
//
// Optional. Default: nil
Done func(c fiber.Ctx, logString []byte)
// tagFunctions defines the custom tag action
//
// Optional. Default: map[string]LogFunc
CustomTags map[string]LogFunc
// You can define specific things before returning the handler: colors, template, etc.
//
// Optional. Default: beforeHandlerFunc
BeforeHandlerFunc func(*Config)
// You can use custom loggers with Fiber by using this field.
// This field is really useful if you're using Zerolog, Zap, Logrus, apex/log etc.
// If you don't define anything for this field, it'll use default logger of Fiber.
//
// Optional. Default: defaultLogger
LoggerFunc func(c fiber.Ctx, data *Data, cfg *Config) error
timeZoneLocation *time.Location
// Format defines the logging format for the middleware.
//
// You can customize the log output by defining a format string with placeholders
// such as: ${time}, ${ip}, ${status}, ${method}, ${path}, ${latency}, ${error}, etc.
// The full list of available placeholders can be found in 'tags.go' or at
// 'https://docs.gofiber.io/api/middleware/logger/#constants'.
//
// Fiber provides predefined logging formats that can be used directly:
//
// - DefaultFormat → Uses the default log format: "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}"
// - CommonFormat → Uses the Apache Common Log Format (CLF): "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent}\n"
// - CombinedFormat → Uses the Apache Combined Log Format: "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent} \"${referer}\" \"${ua}\"\n"
// - JSONFormat → Uses the JSON log format: "{\"time\":\"${time}\",\"ip\":\"${ip}\",\"method\":\"${method}\",\"url\":\"${url}\",\"status\":${status},\"bytesSent\":${bytesSent}}\n"
// - ECSFormat → Uses the Elastic Common Schema (ECS) log format: {\"@timestamp\":\"${time}\",\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":\"${ip}\"},\"http\":{\"request\":{\"method\":\"${method}\",\"url\":\"${url}\",\"protocol\":\"${protocol}\"},\"response\":{\"status_code\":${status},\"body\":{\"bytes\":${bytesSent}}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":\"${method} ${url} responded with ${status}\"}"
// If both `Format` and `CustomFormat` are provided, the `CustomFormat` will be used, and the `Format` field will be ignored.
// If no format is specified, the default format is used:
// "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}"
Format string
// TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html
//
// Optional. Default: 15:04:05
TimeFormat string
// TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc
//
// Optional. Default: "Local"
TimeZone string
// TimeInterval is the delay before the timestamp is updated
//
// Optional. Default: 500 * time.Millisecond
TimeInterval time.Duration
// DisableColors defines if the logs output should be colorized
//
// Default: false
DisableColors bool
// ForceColors forces the colors to be enabled even if the output is not a terminal
//
// Default: false
ForceColors bool
enableColors bool
enableLatency bool
}
const (
startTag = "${"
endTag = "}"
paramSeparator = ":"
)
// Buffer abstracts the buffer operations used when rendering log entries.
type Buffer interface {
Len() int
ReadFrom(r io.Reader) (int64, error)
WriteTo(w io.Writer) (int64, error)
Bytes() []byte
Write(p []byte) (int, error)
WriteByte(c byte) error
WriteString(s string) (int, error)
Set(p []byte)
SetString(s string)
String() string
}
// LogFunc formats logging output using the provided buffer and request data.
type LogFunc func(output Buffer, c fiber.Ctx, data *Data, extraParam string) (int, error)
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Skip: nil,
Done: nil,
Format: DefaultFormat,
TimeFormat: "15:04:05",
TimeZone: "Local",
TimeInterval: 500 * time.Millisecond,
Stream: os.Stdout,
BeforeHandlerFunc: beforeHandlerFunc,
LoggerFunc: defaultLoggerInstance,
enableColors: true,
}
// Helper function to set default values
func configDefault(config ...Config) Config {
// Return default config if nothing provided
if len(config) < 1 {
return ConfigDefault
}
// Override default config
cfg := config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.Skip == nil {
cfg.Skip = ConfigDefault.Skip
}
if cfg.Done == nil {
cfg.Done = ConfigDefault.Done
}
if cfg.Format == "" {
cfg.Format = ConfigDefault.Format
}
if cfg.TimeZone == "" {
cfg.TimeZone = ConfigDefault.TimeZone
}
if cfg.TimeFormat == "" {
cfg.TimeFormat = ConfigDefault.TimeFormat
}
if int(cfg.TimeInterval) <= 0 {
cfg.TimeInterval = ConfigDefault.TimeInterval
}
if cfg.Stream == nil {
cfg.Stream = ConfigDefault.Stream
}
if cfg.BeforeHandlerFunc == nil {
cfg.BeforeHandlerFunc = ConfigDefault.BeforeHandlerFunc
}
if cfg.LoggerFunc == nil {
cfg.LoggerFunc = ConfigDefault.LoggerFunc
}
// Enable colors if no custom format or output is given
if (!cfg.DisableColors && cfg.Stream == ConfigDefault.Stream) || cfg.ForceColors {
cfg.enableColors = true
}
return cfg
}
================================================
FILE: middleware/logger/data.go
================================================
package logger
import (
"sync/atomic"
"time"
)
// Data is a struct to define some variables to use in custom logger function.
type Data struct {
Start time.Time
Stop time.Time
ChainErr error
Timestamp atomic.Value
Pid string
ErrPaddingStr string
TemplateChain [][]byte
LogFuncChain []LogFunc
}
================================================
FILE: middleware/logger/default_logger.go
================================================
package logger
import (
"fmt"
"io"
"os"
"strconv"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
"github.com/mattn/go-colorable"
"github.com/mattn/go-isatty"
"github.com/valyala/bytebufferpool"
)
// default logger for fiber
func defaultLoggerInstance(c fiber.Ctx, data *Data, cfg *Config) error {
if cfg == nil {
cfg = &Config{
Stream: os.Stdout,
Format: DefaultFormat,
enableColors: true,
}
}
// Check if Skip is defined and call it.
// Now, if Skip(c) == true, we SKIP logging:
if cfg.Skip != nil && cfg.Skip(c) {
return nil // Skip logging if Skip returns true
}
// Alias colors
colors := c.App().Config().ColorScheme
// Get new buffer
buf := bytebufferpool.Get()
// Default output when no custom Format or io.Writer is given
if cfg.Format == DefaultFormat {
// Format error if exist
formatErr := ""
if cfg.enableColors {
if data.ChainErr != nil {
formatErr = colors.Red + " | " + data.ChainErr.Error() + colors.Reset
}
fmt.Fprintf(buf,
"%s |%s %3d %s| %13v | %15s |%s %-7s %s| %-"+data.ErrPaddingStr+"s %s\n",
data.Timestamp.Load().(string), //nolint:forcetypeassert,errcheck // Timestamp is always a string
statusColor(c.Response().StatusCode(), &colors), c.Response().StatusCode(), colors.Reset,
data.Stop.Sub(data.Start),
c.IP(),
methodColor(c.Method(), &colors), c.Method(), colors.Reset,
c.Path(),
formatErr,
)
} else {
if data.ChainErr != nil {
formatErr = " | " + data.ChainErr.Error()
}
// Helper function to append fixed-width string with padding
fixedWidth := func(s string, width int, rightAlign bool) {
if rightAlign {
for i := len(s); i < width; i++ {
buf.WriteByte(' ')
}
buf.WriteString(s)
} else {
buf.WriteString(s)
for i := len(s); i < width; i++ {
buf.WriteByte(' ')
}
}
}
// Timestamp
buf.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert,errcheck // Timestamp is always a string
buf.WriteString(" | ")
// Status Code with 3 fixed width, right aligned
fixedWidth(strconv.Itoa(c.Response().StatusCode()), 3, true)
buf.WriteString(" | ")
// Duration with 13 fixed width, right aligned
fixedWidth(data.Stop.Sub(data.Start).String(), 13, true)
buf.WriteString(" | ")
// Client IP with 15 fixed width, right aligned
fixedWidth(c.IP(), 15, true)
buf.WriteString(" | ")
// HTTP Method with 7 fixed width, left aligned
fixedWidth(c.Method(), 7, false)
buf.WriteString(" | ")
// Path with dynamic padding for error message, left aligned
errPadding, _ := strconv.Atoi(data.ErrPaddingStr) //nolint:errcheck // It is fine to ignore the error
fixedWidth(c.Path(), errPadding, false)
// Error message
buf.WriteString(" ")
buf.WriteString(formatErr)
buf.WriteString("\n")
}
// Write buffer to output
writeLog(cfg.Stream, buf.Bytes())
if cfg.Done != nil {
cfg.Done(c, buf.Bytes())
}
// Put buffer back to pool
bytebufferpool.Put(buf)
// End chain
return nil
}
var err error
// Loop over template parts execute dynamic parts and add fixed parts to the buffer
for i, logFunc := range data.LogFuncChain {
switch {
case logFunc == nil:
buf.Write(data.TemplateChain[i])
case data.TemplateChain[i] == nil:
_, err = logFunc(buf, c, data, "")
default:
_, err = logFunc(buf, c, data, utils.UnsafeString(data.TemplateChain[i]))
}
if err != nil {
break
}
}
// Also write errors to the buffer
if err != nil {
buf.WriteString(err.Error())
}
writeLog(cfg.Stream, buf.Bytes())
if cfg.Done != nil {
cfg.Done(c, buf.Bytes())
}
// Put buffer back to pool
bytebufferpool.Put(buf)
return nil
}
// run something before returning the handler
func beforeHandlerFunc(cfg *Config) {
if cfg == nil {
return
}
// If colors are enabled, check terminal compatibility
if cfg.enableColors && cfg.Stream == os.Stdout {
cfg.Stream = colorable.NewColorableStdout()
if !cfg.ForceColors && (os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd()))) {
cfg.Stream = colorable.NewNonColorable(os.Stdout)
}
}
}
func appendInt(output Buffer, v int) (int, error) {
old := output.Len()
output.Set(strconv.AppendInt(output.Bytes(), int64(v), 10))
return output.Len() - old, nil
}
// writeLog writes a msg to w, printing a warning to stderr if the log fails.
func writeLog(w io.Writer, msg []byte) {
if _, err := w.Write(msg); err != nil {
// Write error to output
if _, writeErr := w.Write([]byte(err.Error())); writeErr != nil {
// There is something wrong with the given io.Writer
_, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", writeErr)
}
}
}
================================================
FILE: middleware/logger/errors.go
================================================
package logger
import (
"errors"
)
// ErrTemplateParameterMissing indicates that a template parameter was referenced but not provided.
var ErrTemplateParameterMissing = errors.New("logger: template parameter missing")
================================================
FILE: middleware/logger/format.go
================================================
package logger
const (
// Fiber's default logger
DefaultFormat = "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n"
// Apache Common Log Format (CLF)
CommonFormat = "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent}\n"
// Apache Combined Log Format
CombinedFormat = "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent} \"${referer}\" \"${ua}\"\n"
// JSON log formats
JSONFormat = "{\"time\":\"${time}\",\"ip\":\"${ip}\",\"method\":\"${method}\",\"url\":\"${url}\",\"status\":${status},\"bytesSent\":${bytesSent}}\n"
// Elastic Common Schema (ECS) Log Format
ECSFormat = "{\"@timestamp\":\"${time}\",\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":\"${ip}\"},\"http\":{\"request\":{\"method\":\"${method}\",\"url\":\"${url}\",\"protocol\":\"${protocol}\"},\"response\":{\"status_code\":${status},\"body\":{\"bytes\":${bytesSent}}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":\"${method} ${url} responded with ${status}\"}\n"
)
================================================
FILE: middleware/logger/logger.go
================================================
package logger
import (
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v3"
)
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := configDefault(config...)
// Get timezone location
tz, err := time.LoadLocation(cfg.TimeZone)
if err != nil || tz == nil {
cfg.timeZoneLocation = time.Local
} else {
cfg.timeZoneLocation = tz
}
// Check if format contains latency
cfg.enableLatency = strings.Contains(cfg.Format, "${"+TagLatency+"}")
var timestamp atomic.Value
// Create correct timeformat
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
// Update date/time every 500 milliseconds in a separate go routine
if strings.Contains(cfg.Format, "${"+TagTime+"}") {
go func() {
for {
time.Sleep(cfg.TimeInterval)
timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat))
}
}()
}
// Set PID once
pid := strconv.Itoa(os.Getpid())
// Set variables
var (
once sync.Once
errHandler fiber.ErrorHandler
dataPool = sync.Pool{New: func() any { return new(Data) }}
)
// Err padding
errPadding := 15
errPaddingStr := strconv.Itoa(errPadding)
// Before handling func
cfg.BeforeHandlerFunc(&cfg)
// Logger data
// instead of analyzing the template inside(handler) each time, this is done once before
// and we create several slices of the same length with the functions to be executed and fixed parts.
templateChain, logFunChain, err := buildLogFuncChain(&cfg, createTagMap(&cfg))
if err != nil {
panic(err)
}
// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Set error handler once
once.Do(func() {
// get longest possible path
stack := c.App().Stack()
for m := range stack {
for r := range stack[m] {
if len(stack[m][r].Path) > errPadding {
errPadding = len(stack[m][r].Path)
errPaddingStr = strconv.Itoa(errPadding)
}
}
}
// override error handler
errHandler = c.App().ErrorHandler
})
// Logger data
data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
// no need for a reset, as long as we always override everything
data.Pid = pid
data.ErrPaddingStr = errPaddingStr
data.Timestamp = timestamp
data.TemplateChain = templateChain
data.LogFuncChain = logFunChain
// put data back in the pool
defer dataPool.Put(data)
// Set latency start time
if cfg.enableLatency {
data.Start = time.Now()
}
// Handle request, store err for logging
chainErr := c.Next()
data.ChainErr = chainErr
// Manually call error handler
if chainErr != nil {
if err := errHandler(c, chainErr); err != nil {
_ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here
}
}
// Set latency stop time
if cfg.enableLatency {
data.Stop = time.Now()
}
// Logger instance & update some logger data fields
return cfg.LoggerFunc(c, data, &cfg)
}
}
================================================
FILE: middleware/logger/logger_test.go
================================================
//nolint:depguard // Because we test logging :D
package logger
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
"os"
"regexp"
"runtime"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v3"
fiberlog "github.com/gofiber/fiber/v3/log"
"github.com/gofiber/fiber/v3/middleware/requestid"
)
const (
pathFooBar = "/?foo=bar"
httpProto = "HTTP/1.1"
)
func benchmarkSetup(b *testing.B, app *fiber.App, uri string) {
b.Helper()
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI(uri)
b.ReportAllocs()
for b.Loop() {
h(fctx)
}
}
func benchmarkSetupParallel(b *testing.B, app *fiber.App, path string) {
b.Helper()
handler := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(fiber.MethodGet)
fctx.Request.SetRequestURI(path)
for pb.Next() {
handler(fctx)
}
})
}
// go test -run Test_Logger
func Test_Logger(t *testing.T) {
t.Parallel()
app := fiber.New()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app.Use(New(Config{
Format: "${error}",
Stream: buf,
}))
app.Get("/", func(_ fiber.Ctx) error {
return errors.New("some random error")
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
require.Equal(t, "some random error", buf.String())
}
// go test -run Test_Logger_locals
func Test_Logger_locals(t *testing.T) {
t.Parallel()
app := fiber.New()
buf := bytebufferpool.Get()
defer bytebufferpool.Put(buf)
app.Use(New(Config{
Format: "${locals:demo}",
Stream: buf,
}))
app.Get("/", func(c fiber.Ctx) error {
c.Locals("demo", "johndoe")
return c.SendStatus(fiber.StatusOK)
})
app.Get("/int", func(c fiber.Ctx) error {
c.Locals("demo", 55)
return c.SendStatus(fiber.StatusOK)
})
app.Get("/empty", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "johndoe", buf.String())
buf.Reset()
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/int", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Equal(t, "55", buf.String())
buf.Reset()
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/empty", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Empty(t, buf.String())
}
// go test -run Test_Logger_Next
func Test_Logger_Next(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_Logger_Done
func Test_Logger_Done(t *testing.T) {
t.Parallel()
buf := bytes.NewBuffer(nil)
app := fiber.New()
app.Use(New(Config{
Done: func(c fiber.Ctx, logString []byte) {
if c.Response().StatusCode() == fiber.StatusOK {
_, err := buf.Write(logString)
require.NoError(t, err)
}
},
})).Get("/logging", func(ctx fiber.Ctx) error {
return ctx.SendStatus(fiber.StatusOK)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/logging", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
require.Positive(t, buf.Len(), 0)
}
// Test_Logger_Filter tests the Filter functionality of the logger middleware.
// It verifies that logs are written or skipped based on the filter condition.
func Test_Logger_Filter(t *testing.T) {
t.Parallel()
t.Run("Test Not Found", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Return true to skip logging for all requests != 404
app.Use(New(Config{
Skip: func(c fiber.Ctx) bool {
return c.Response().StatusCode() != fiber.StatusNotFound
},
Stream: &logOutput,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/nonexistent", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
// Expect logs for the 404 request
require.Contains(t, logOutput.String(), "404")
})
t.Run("Test OK", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Return true to skip logging for all requests == 200
app.Use(New(Config{
Skip: func(c fiber.Ctx) bool {
return c.Response().StatusCode() == fiber.StatusOK
},
Stream: &logOutput,
}))
app.Get("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
// We skip logging for status == 200, so "200" should not appear
require.NotContains(t, logOutput.String(), "200")
})
t.Run("Always Skip", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Filter always returns true => skip all logs
app.Use(New(Config{
Skip: func(_ fiber.Ctx) bool {
return true // always skip
},
Stream: &logOutput,
}))
app.Get("/something", func(c fiber.Ctx) error {
return c.Status(fiber.StatusTeapot).SendString("I'm a teapot")
})
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/something", http.NoBody))
require.NoError(t, err)
// Expect NO logs
require.Empty(t, logOutput.String())
})
t.Run("Never Skip", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Filter always returns false => never skip logs
app.Use(New(Config{
Skip: func(_ fiber.Ctx) bool {
return false // never skip
},
Stream: &logOutput,
}))
app.Get("/always", func(c fiber.Ctx) error {
return c.Status(fiber.StatusTeapot).SendString("Teapot again")
})
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/always", http.NoBody))
require.NoError(t, err)
// Expect some logging - check any substring
require.Contains(t, logOutput.String(), strconv.Itoa(fiber.StatusTeapot))
})
t.Run("Skip /healthz", func(t *testing.T) {
t.Parallel()
app := fiber.New()
logOutput := bytes.Buffer{}
// Filter returns true (skip logs) if the request path is /healthz
app.Use(New(Config{
Skip: func(c fiber.Ctx) bool {
return c.Path() == "/healthz"
},
Stream: &logOutput,
}))
// Normal route
app.Get("/", func(c fiber.Ctx) error {
return c.SendString("Hello World!")
})
// Health route
app.Get("/healthz", func(c fiber.Ctx) error {
return c.SendString("OK")
})
// Request to "/" -> should be logged
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Contains(t, logOutput.String(), "200")
// Reset output buffer
logOutput.Reset()
// Request to "/healthz" -> should be skipped
_, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/healthz", http.NoBody))
require.NoError(t, err)
require.Empty(t, logOutput.String())
})
}
// go test -run Test_Logger_ErrorTimeZone
func Test_Logger_ErrorTimeZone(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
TimeZone: "invalid",
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_Logger_LoggerToWriter
//
// Test_Logger_LoggerToWriter checks that LoggerToWriter turns a fiber logger into
// a Stream for the logger middleware: for each level from Trace to Error, the
// rendered "${error}" line must reach the underlying writer prefixed with the
// level name in brackets. It also checks that LoggerToWriter panics for
// LevelPanic, LevelFatal and a nil logger.
//
// The test is not run in parallel because it changes the flags and output of the
// logger returned by fiberlog.DefaultLogger, which is presumably shared state.
func Test_Logger_LoggerToWriter(t *testing.T) {
	app := fiber.New()

	buf := bytebufferpool.Get()
	t.Cleanup(func() {
		bytebufferpool.Put(buf)
	})

	logger := fiberlog.DefaultLogger[*log.Logger]()
	stdlogger := logger.Logger()
	// No date/time prefix, so the captured output can be compared verbatim.
	stdlogger.SetFlags(0)
	logger.SetOutput(buf)

	testCases := []struct {
		levelStr string
		level    fiberlog.Level
	}{
		{level: fiberlog.LevelTrace, levelStr: "Trace"},
		{level: fiberlog.LevelDebug, levelStr: "Debug"},
		{level: fiberlog.LevelInfo, levelStr: "Info"},
		{level: fiberlog.LevelWarn, levelStr: "Warn"},
		{level: fiberlog.LevelError, levelStr: "Error"},
	}

	for _, tc := range testCases {
		level := strconv.Itoa(int(tc.level))
		t.Run(level, func(t *testing.T) {
			buf.Reset()

			// Mount a dedicated middleware and handler under a per-level path so
			// each subtest only exercises the writer created for its own level.
			app.Use("/"+level, New(Config{
				Format: "${error}",
				Stream: LoggerToWriter(logger, tc.level),
			}))
			app.Get("/"+level, func(_ fiber.Ctx) error {
				return errors.New("some random error")
			})

			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/"+level, http.NoBody))
			require.NoError(t, err)
			require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
			require.Equal(t, "["+tc.levelStr+"] some random error\n", buf.String())
		})
	}

	// These checks do not depend on the table above, so they run once rather
	// than once per level.
	require.Panics(t, func() {
		LoggerToWriter(logger, fiberlog.LevelPanic)
	})
	require.Panics(t, func() {
		LoggerToWriter(logger, fiberlog.LevelFatal)
	})
	require.Panics(t, func() {
		LoggerToWriter[any](nil, fiberlog.LevelFatal)
	})
}
// fakeErrorOutput is a writer stub that counts how many times Write is called
// and fails every call without consuming any bytes.
type fakeErrorOutput int

// Write bumps the call counter and reports a "fake output" error.
func (w *fakeErrorOutput) Write(_ []byte) (int, error) {
	*w++
	return 0, errors.New("fake output")
}
// go test -run Test_Logger_ErrorOutput_WithoutColor
//
// Test_Logger_ErrorOutput_WithoutColor attaches a stream whose writes always
// fail (colors disabled). It asserts that the request still completes with a 404
// and that the stream received exactly two write attempts.
func Test_Logger_ErrorOutput_WithoutColor(t *testing.T) {
	t.Parallel()

	failing := new(fakeErrorOutput)

	app := fiber.New()
	app.Use(New(Config{
		Stream:        failing,
		DisableColors: true,
	}))

	res, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusNotFound, res.StatusCode)
	require.EqualValues(t, 2, *failing)
}
// go test -run Test_Logger_ErrorOutput
//
// Test_Logger_ErrorOutput is the default-color variant of the failing-stream
// test. It checks for a 404 response and two write attempts on the broken stream.
func Test_Logger_ErrorOutput(t *testing.T) {
	t.Parallel()

	failing := new(fakeErrorOutput)

	app := fiber.New()
	app.Use(New(Config{Stream: failing}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	res, testErr := app.Test(req)
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusNotFound, res.StatusCode)
	require.EqualValues(t, 2, *failing)
}
// go test -run Test_Logger_All
//
// Test_Logger_All renders a large set of built-in tags, plus the unknown tag
// "${non}", for an unrouted GET request. It compares the whole line with the
// expected output, including the raw color codes from the app's ColorScheme.
func Test_Logger_All(t *testing.T) {
	t.Parallel()

	const format = "${pid}${reqHeaders}${referer}${scheme}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${reqHeader:test}${query:test}${form:test}${cookie:test}${non}"

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format: format,
		Stream: out,
	}))
	cs := app.Config().ColorScheme

	res, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, pathFooBar, http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusNotFound, res.StatusCode)

	want := fmt.Sprintf("%dHost=example.comhttpHTTP/1.10.0.0.0example.com/?foo=bar/%s%s%s%s%s%s%s%s%sNot Found",
		os.Getpid(), cs.Black, cs.Red, cs.Green, cs.Yellow, cs.Blue, cs.Magenta, cs.Cyan, cs.White, cs.Reset)
	require.Equal(t, want, out.String())
}
// Test_Logger_CLF_Format checks that CommonFormat renders an unrouted GET
// request as a Common Log Format line (client IP, time, request line, status,
// bytes sent).
func Test_Logger_CLF_Format(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format: CommonFormat,
		Stream: out,
	}))

	reqMethod := fiber.MethodGet
	wantStatus := fiber.StatusNotFound
	wantBytes := 0

	res, testErr := app.Test(httptest.NewRequest(reqMethod, pathFooBar, http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, wantStatus, res.StatusCode)

	pattern := fmt.Sprintf(`0\.0\.0\.0 - - \[\d{2}:\d{2}:\d{2}\] "%s %s %s" %d %d`,
		reqMethod, regexp.QuoteMeta(pathFooBar), httpProto, wantStatus, wantBytes)
	require.Regexp(t, pattern, out.String())
}
// Test_Logger_Combined_CLF_Format checks that CombinedFormat extends the Common
// Log Format line with the quoted Referer and User-Agent request headers.
func Test_Logger_Combined_CLF_Format(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format: CombinedFormat,
		Stream: out,
	}))

	reqMethod := fiber.MethodGet
	wantStatus := fiber.StatusNotFound
	wantBytes := 0
	refererURL := "http://example.com"
	userAgent := "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/74.0.3729.169 Safari/537.36"

	request := httptest.NewRequest(reqMethod, pathFooBar, http.NoBody)
	request.Header.Set("Referer", refererURL)
	request.Header.Set("User-Agent", userAgent)

	res, testErr := app.Test(request)
	require.NoError(t, testErr)
	require.Equal(t, wantStatus, res.StatusCode)

	pattern := fmt.Sprintf(`0\.0\.0\.0 - - \[\d{2}:\d{2}:\d{2}\] "%s %s %s" %d %d "%s" "%s"`, reqMethod, regexp.QuoteMeta(pathFooBar), httpProto, wantStatus, wantBytes, regexp.QuoteMeta(refererURL), regexp.QuoteMeta(userAgent)) //nolint:gocritic // double quoting for regex and string is not needed
	require.Regexp(t, pattern, out.String())
}
// Test_Logger_Json_Format checks that JSONFormat renders an unrouted GET request
// as a single JSON object with time, ip, method, url, status and bytesSent.
func Test_Logger_Json_Format(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format: JSONFormat,
		Stream: out,
	}))

	reqMethod := fiber.MethodGet
	wantStatus := fiber.StatusNotFound
	clientIP := "0.0.0.0"
	wantBytes := 0

	res, testErr := app.Test(httptest.NewRequest(reqMethod, pathFooBar, http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, wantStatus, res.StatusCode)

	pattern := fmt.Sprintf(`\{"time":"\d{2}:\d{2}:\d{2}","ip":"%s","method":%q,"url":"%s","status":%d,"bytesSent":%d\}`, regexp.QuoteMeta(clientIP), reqMethod, regexp.QuoteMeta(pathFooBar), wantStatus, wantBytes) //nolint:gocritic // double quoting for regex and string is not needed
	require.Regexp(t, pattern, out.String())
}
// Test_Logger_ECS_Format checks that ECSFormat renders an unrouted GET request as
// an Elastic Common Schema JSON document (ECS version 1.6.0) with a
// human-readable summary message.
func Test_Logger_ECS_Format(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format: ECSFormat,
		Stream: out,
	}))

	reqMethod := fiber.MethodGet
	wantStatus := fiber.StatusNotFound
	clientIP := "0.0.0.0"
	wantBytes := 0
	summary := fmt.Sprintf("%s %s responded with %d", reqMethod, pathFooBar, wantStatus)

	res, testErr := app.Test(httptest.NewRequest(reqMethod, pathFooBar, http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, wantStatus, res.StatusCode)

	pattern := fmt.Sprintf(`\{"@timestamp":"\d{2}:\d{2}:\d{2}","ecs":\{"version":"1.6.0"\},"client":\{"ip":"%s"\},"http":\{"request":\{"method":%q,"url":"%s","protocol":%q\},"response":\{"status_code":%d,"body":\{"bytes":%d\}\}\},"log":\{"level":"INFO","logger":"fiber"\},"message":"%s"\}`, regexp.QuoteMeta(clientIP), reqMethod, regexp.QuoteMeta(pathFooBar), httpProto, wantStatus, wantBytes, regexp.QuoteMeta(summary)) //nolint:gocritic // double quoting for regex and string is not needed
	require.Regexp(t, pattern, out.String())
}
// getLatencyTimeUnits returns the latency units the latency tests iterate over,
// from smallest to largest, each paired with a duration of one such unit.
//
// Windows does not support microsecond sleep precision
// (https://github.com/golang/go/issues/29485), so the "µs" entry is left out there.
func getLatencyTimeUnits() []struct {
	unit string
	div  time.Duration
} {
	units := []struct {
		unit string
		div  time.Duration
	}{
		{unit: "µs", div: time.Microsecond},
		{unit: "ms", div: time.Millisecond},
		{unit: "s", div: time.Second},
	}
	if runtime.GOOS == "windows" {
		return units[1:]
	}
	return units
}
// go test -run Test_Logger_WithLatency
//
// Test_Logger_WithLatency makes the handler sleep for one unit of each entry
// returned by getLatencyTimeUnits. It asserts that the "${latency}" output ends
// with that unit's suffix.
func Test_Logger_WithLatency(t *testing.T) {
	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Stream: out,
		Format: "${latency}",
	}))

	// The handler reads delay on every request; the loop below updates it
	// before each call.
	delay := 1 * time.Nanosecond
	app.Get("/test", func(c fiber.Ctx) error {
		time.Sleep(delay)
		return c.SendStatus(fiber.StatusOK)
	})

	testCfg := fiber.TestConfig{
		Timeout:       3 * time.Second,
		FailOnTimeout: true,
	}
	for _, tu := range getLatencyTimeUnits() {
		delay = tu.div

		res, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), testCfg)
		require.NoError(t, testErr)
		require.Equal(t, fiber.StatusOK, res.StatusCode)

		require.True(t, bytes.HasSuffix(out.Bytes(), []byte(tu.unit)),
			"Expected latency to be in %s, got %s", tu.unit, out.String())

		out.Reset()
	}
}
// go test -run Test_Logger_WithLatency_DefaultFormat
//
// Test_Logger_WithLatency_DefaultFormat is the default-format counterpart of
// Test_Logger_WithLatency. It takes the third " | "-separated field of the log
// line as the latency and checks that it ends with the expected unit suffix.
func Test_Logger_WithLatency_DefaultFormat(t *testing.T) {
	buff := bytebufferpool.Get()
	defer bytebufferpool.Put(buff)

	app := fiber.New()
	logger := New(Config{
		Stream: buff,
	})
	app.Use(logger)

	// Define a list of time units to test
	timeUnits := getLatencyTimeUnits()

	// The handler reads sleepDuration on every request; the loop updates it.
	sleepDuration := 1 * time.Nanosecond
	app.Get("/test", func(c fiber.Ctx) error {
		time.Sleep(sleepDuration)
		return c.SendStatus(fiber.StatusOK)
	})

	for _, tu := range timeUnits {
		// Update the sleep duration for the next iteration
		sleepDuration = 1 * tu.div

		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{
			Timeout:       2 * time.Second,
			FailOnTimeout: true,
		})
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)

		// The latency is the third " | "-separated field of the default format.
		// Check the field count first so a malformed line fails the assertion
		// instead of panicking with an index-out-of-range error.
		fields := bytes.Split(buff.Bytes(), []byte(" | "))
		require.GreaterOrEqual(t, len(fields), 3, "unexpected log line: %s", buff.String())
		latency := fields[2]

		require.True(t, bytes.HasSuffix(latency, []byte(tu.unit)), "Expected latency to be in %s, got %s", tu.unit, latency)

		buff.Reset()
	}
}
// go test -run Test_Query_Params
//
// Test_Query_Params checks that "${queryParams}" writes the raw query string of
// the request.
func Test_Query_Params(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format: "${queryParams}",
		Stream: out,
	}))

	const query = "foo=bar&baz=moz"
	res, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/?"+query, http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusNotFound, res.StatusCode)
	require.Equal(t, query, out.String())
}
// go test -run Test_Response_Body
//
// Test_Response_Body checks that "${resBody}" logs the response body for both a
// SendString (GET) and a Send (POST) handler.
func Test_Response_Body(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format: "${resBody}",
		Stream: out,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Sample response body")
	})
	app.Post("/test", func(c fiber.Ctx) error {
		return c.Send([]byte("Post in test"))
	})

	for _, tc := range []struct {
		method string
		path   string
		want   string
	}{
		{method: fiber.MethodGet, path: "/", want: "Sample response body"},
		{method: fiber.MethodPost, path: "/test", want: "Post in test"},
	} {
		_, testErr := app.Test(httptest.NewRequest(tc.method, tc.path, http.NoBody))
		require.NoError(t, testErr)
		require.Equal(t, tc.want, out.String())

		// Start the next request with an empty log buffer.
		out.Reset()
	}
}
// go test -run Test_Request_Body
//
// Test_Request_Body posts a 5-byte body to a handler that answers with a 5-byte
// body. It checks that ${bytesReceived}, ${bytesSent} and ${status} report both
// Content-Length values and the status code.
func Test_Request_Body(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format: "${bytesReceived} ${bytesSent} ${status}",
		Stream: out,
	}))
	app.Post("/", func(c fiber.Ctx) error {
		c.Response().Header.SetContentLength(5)
		return c.SendString("World")
	})

	payload := bytes.NewReader([]byte("Hello"))
	request := httptest.NewRequest(fiber.MethodPost, "/", payload)
	request.Header.Set("Content-Type", "application/octet-stream")

	_, testErr := app.Test(request)
	require.NoError(t, testErr)
	require.Equal(t, "5 5 200", out.String())
}
// go test -run Test_Logger_AppendUint
//
// Test_Logger_AppendUint checks the integer tags (${bytesReceived},
// ${bytesSent}, ${status}) with and without an explicit response Content-Length.
func Test_Logger_AppendUint(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app.Use(New(Config{
		Format: "${bytesReceived} ${bytesSent} ${status}",
		Stream: out,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hello")
	})
	app.Get("/content", func(c fiber.Ctx) error {
		c.Response().Header.SetContentLength(5)
		return c.SendString("hello")
	})

	for _, tc := range []struct {
		path string
		want string
	}{
		{path: "/", want: "-2 0 200"},
		{path: "/content", want: "-2 5 200"},
	} {
		res, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, tc.path, http.NoBody))
		require.NoError(t, testErr)
		require.Equal(t, fiber.StatusOK, res.StatusCode)
		require.Equal(t, tc.want, out.String())

		out.Reset()
	}
}
// go test -run Test_Logger_Data_Race -race
//
// Test_Logger_Data_Race sends two requests concurrently through two stacked
// logger middlewares (ConfigDefault and a custom format that reads a local).
// Run it with -race so the race detector can report unsynchronized access inside
// the middleware. Both requests must succeed with 200.
func Test_Logger_Data_Race(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New(ConfigDefault))
	app.Use(New(Config{
		Format: "${time} | ${pid} | ${locals:requestid} | ${status} | ${latency} | ${method} | ${path}\n",
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hello")
	})

	var (
		resp1, resp2 *http.Response
		err1, err2   error
		wg           sync.WaitGroup
	)
	// One request runs on a helper goroutine while the other runs here, so the
	// two overlap inside the middleware.
	wg.Go(func() {
		resp1, err1 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	})
	resp2, err2 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	wg.Wait()

	require.NoError(t, err1)
	require.Equal(t, fiber.StatusOK, resp1.StatusCode)
	require.NoError(t, err2)
	require.Equal(t, fiber.StatusOK, resp2.StatusCode)
}
// go test -run Test_Response_Header
//
// Test_Response_Header checks that "${respHeader:X-Request-ID}" logs the response
// header set by the requestid middleware (with a fixed generator).
func Test_Response_Header(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	const fixedID = "Hello fiber!"

	app := fiber.New()
	app.Use(requestid.New(requestid.Config{
		Header:    fiber.HeaderXRequestID,
		Generator: func() string { return fixedID },
	}))
	app.Use(New(Config{
		Format: "${respHeader:X-Request-ID}",
		Stream: out,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello fiber!")
	})

	res, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusOK, res.StatusCode)
	require.Equal(t, fixedID, out.String())
}
// go test -run Test_Req_Header
//
// Test_Req_Header checks that "${reqHeader:test}" logs the value of the "test"
// request header.
func Test_Req_Header(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format: "${reqHeader:test}",
		Stream: out,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello fiber!")
	})

	const headerValue = "Hello fiber!"
	request := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	request.Header.Add("test", headerValue)

	res, testErr := app.Test(request)
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusOK, res.StatusCode)
	require.Equal(t, headerValue, out.String())
}
// go test -run Test_ReqHeader_Header
//
// Test_ReqHeader_Header logs a custom request header through the parameterized
// "${reqHeader:...}" tag and expects its exact value in the output.
func Test_ReqHeader_Header(t *testing.T) {
	t.Parallel()

	logBuf := bytebufferpool.Get()
	defer bytebufferpool.Put(logBuf)

	app := fiber.New()
	app.Use(New(Config{
		Format: "${reqHeader:test}",
		Stream: logBuf,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello fiber!")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("test", "Hello fiber!")

	res, testErr := app.Test(req)
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusOK, res.StatusCode)
	require.Equal(t, "Hello fiber!", logBuf.String())
}
// go test -run Test_CustomTags
//
// Test_CustomTags registers a user-defined tag through Config.CustomTags and
// checks that the format "${custom_tag}" renders exactly what that LogFunc writes.
func Test_CustomTags(t *testing.T) {
	t.Parallel()

	const tagText = "it is a custom tag"

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	writeTag := func(output Buffer, _ fiber.Ctx, _ *Data, _ string) (int, error) {
		return output.WriteString(tagText)
	}

	app := fiber.New()
	app.Use(New(Config{
		Format:     "${custom_tag}",
		CustomTags: map[string]LogFunc{"custom_tag": writeTag},
		Stream:     out,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello fiber!")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("test", "Hello fiber!")

	res, testErr := app.Test(req)
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusOK, res.StatusCode)
	require.Equal(t, tagText, out.String())
}
// go test -run Test_Logger_ByteSent_Streaming
//
// Test_Logger_ByteSent_Streaming streams a chunked body through
// SetBodyStreamWriter. It checks how the byte-count tags report a response
// without a fixed Content-Length.
func Test_Logger_ByteSent_Streaming(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app.Use(New(Config{
		Format: "${bytesReceived} ${bytesSent} ${status}",
		Stream: out,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set("Connection", "keep-alive")
		c.Set("Transfer-Encoding", "chunked")
		c.RequestCtx().SetBodyStreamWriter(func(w *bufio.Writer) {
			// Emit up to ten SSE-style messages, stopping early if a flush fails.
			for n := 1; n <= 10; n++ {
				msg := fmt.Sprintf("%d - the time is %v", n, time.Now())
				fmt.Fprintf(w, "data: Message: %s\n\n", msg)
				if flushErr := w.Flush(); flushErr != nil {
					break
				}
			}
		})
		return nil
	})

	res, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusOK, res.StatusCode)

	// -2 means identity, -1 means chunked, 200 status
	require.Equal(t, "-2 -1 200", out.String())
}
// fakeOutput is a writer stub that counts Write calls and accepts every byte it
// is given.
type fakeOutput int

// Write bumps the call counter and reports the whole slice as written.
func (w *fakeOutput) Write(p []byte) (int, error) {
	written := len(p)
	*w++
	return written, nil
}
// go test -run Test_Logger_EnableColors
//
// Test_Logger_EnableColors runs the logger with its default format and color
// settings against a counting writer. It expects a 404 and exactly one Write
// for the request.
func Test_Logger_EnableColors(t *testing.T) {
	t.Parallel()

	counter := new(fakeOutput)

	app := fiber.New()
	app.Use(New(Config{Stream: counter}))

	res, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusNotFound, res.StatusCode)
	require.EqualValues(t, 1, *counter)
}
// go test -run Test_Logger_ForceColors
//
// Test_Logger_ForceColors sets both DisableColors and ForceColors. It expects
// the status, method and error tags to be wrapped in the ColorScheme codes,
// i.e. ForceColors wins.
func Test_Logger_ForceColors(t *testing.T) {
	t.Parallel()

	out := bytebufferpool.Get()
	defer bytebufferpool.Put(out)

	app := fiber.New()
	app.Use(New(Config{
		Format:        "${ip}${status}${method}${path}${error}\n",
		Stream:        out,
		DisableColors: true,
		ForceColors:   true,
	}))
	cs := app.Config().ColorScheme

	res, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusNotFound, res.StatusCode)

	want := fmt.Sprintf("0.0.0.0%s404%s%sGET%s/%sNot Found%s\n",
		cs.Yellow, cs.Reset, cs.Cyan, cs.Reset, cs.Red, cs.Reset)
	require.Equal(t, want, out.String())
}
// go test -v -run=^$ -bench=Benchmark_Logger$ -benchmem -count=4
//
// Benchmark_Logger measures sequential request handling through the logger
// middleware for a range of formats and color/output settings. A
// middleware-free app serves as the baseline. All log output is discarded.
func Benchmark_Logger(b *testing.B) {
	const allTagsFormat = "${pid}${reqHeaders}${referer}${scheme}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${reqHeader:test}${query:test}${form:test}${cookie:test}${non}"

	sendHello := func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	}
	sendHelloWithHeader := func(c fiber.Ctx) error {
		c.Set("test", "test")
		return c.SendString("Hello, World!")
	}
	// sendChunked streams up to ten SSE-style messages, stopping early if a
	// flush fails.
	sendChunked := func(c fiber.Ctx) error {
		c.Set("Connection", "keep-alive")
		c.Set("Transfer-Encoding", "chunked")
		c.RequestCtx().SetBodyStreamWriter(func(w *bufio.Writer) {
			for n := 1; n <= 10; n++ {
				msg := fmt.Sprintf("%d - the time is %v", n, time.Now())
				fmt.Fprintf(w, "data: Message: %s\n\n", msg)
				if flushErr := w.Flush(); flushErr != nil {
					break
				}
			}
		})
		return nil
	}

	benchCases := []struct {
		name  string
		path  string
		setup func(app *fiber.App)
	}{
		{name: "NoMiddleware", path: "/", setup: func(app *fiber.App) {
			app.Get("/", sendHello)
		}},
		{name: "WithBytesAndStatus", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{
				Format: "${bytesReceived} ${bytesSent} ${status}",
				Stream: io.Discard,
			}))
			app.Get("/", sendHelloWithHeader)
		}},
		{name: "DefaultFormat", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Stream: io.Discard}))
			app.Get("/", sendHello)
		}},
		{name: "DefaultFormatDisableColors", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Stream: io.Discard, DisableColors: true}))
			app.Get("/", sendHello)
		}},
		{name: "DefaultFormatForceColors", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Stream: io.Discard, ForceColors: true}))
			app.Get("/", sendHello)
		}},
		{name: "DefaultFormatWithFiberLog", path: "/", setup: func(app *fiber.App) {
			logger := fiberlog.DefaultLogger[*log.Logger]()
			logger.SetOutput(io.Discard)
			app.Use(New(Config{Stream: LoggerToWriter(logger, fiberlog.LevelDebug)}))
			app.Get("/", sendHello)
		}},
		{name: "WithTagParameter", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{
				Format: "${bytesReceived} ${bytesSent} ${status} ${reqHeader:test}",
				Stream: io.Discard,
			}))
			app.Get("/", sendHelloWithHeader)
		}},
		{name: "WithLocals", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Format: "${locals:demo}", Stream: io.Discard}))
			app.Get("/", func(c fiber.Ctx) error {
				c.Locals("demo", "johndoe")
				return c.SendStatus(fiber.StatusOK)
			})
		}},
		{name: "WithLocalsInt", path: "/int", setup: func(app *fiber.App) {
			app.Use(New(Config{Format: "${locals:demo}", Stream: io.Discard}))
			app.Get("/int", func(c fiber.Ctx) error {
				c.Locals("demo", 55)
				return c.SendStatus(fiber.StatusOK)
			})
		}},
		{name: "WithCustomDone", path: "/logging", setup: func(app *fiber.App) {
			app.Use(New(Config{
				Done: func(c fiber.Ctx, logString []byte) {
					if c.Response().StatusCode() == fiber.StatusOK {
						io.Discard.Write(logString) //nolint:errcheck // ignore error
					}
				},
				Stream: io.Discard,
			}))
			app.Get("/logging", func(c fiber.Ctx) error {
				return c.SendStatus(fiber.StatusOK)
			})
		}},
		{name: "WithAllTags", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Format: allTagsFormat, Stream: io.Discard}))
			app.Get("/", sendHello)
		}},
		{name: "Streaming", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{
				Format: "${bytesReceived} ${bytesSent} ${status}",
				Stream: io.Discard,
			}))
			app.Get("/", sendChunked)
		}},
		{name: "WithBody", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Format: "${resBody}", Stream: io.Discard}))
			app.Get("/", func(c fiber.Ctx) error {
				return c.SendString("Sample response body")
			})
		}},
	}

	for _, bc := range benchCases {
		b.Run(bc.name, func(bb *testing.B) {
			// A fresh app per run keeps sub-benchmarks independent.
			app := fiber.New()
			bc.setup(app)
			benchmarkSetup(bb, app, bc.path)
		})
	}
}
// go test -v -run=^$ -bench=Benchmark_Logger_Parallel$ -benchmem -count=4
//
// Benchmark_Logger_Parallel is the concurrent counterpart of Benchmark_Logger.
// It runs the same set of logger configurations through benchmarkSetupParallel,
// with all log output discarded.
func Benchmark_Logger_Parallel(b *testing.B) {
	const allTagsFormat = "${pid}${reqHeaders}${referer}${scheme}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${reqHeader:test}${query:test}${form:test}${cookie:test}${non}"

	greet := func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	}
	greetWithHeader := func(c fiber.Ctx) error {
		c.Set("test", "test")
		return c.SendString("Hello, World!")
	}
	// streamMessages emits up to ten SSE-style chunks, stopping early if a
	// flush fails.
	streamMessages := func(c fiber.Ctx) error {
		c.Set("Connection", "keep-alive")
		c.Set("Transfer-Encoding", "chunked")
		c.RequestCtx().SetBodyStreamWriter(func(w *bufio.Writer) {
			for n := 1; n <= 10; n++ {
				msg := fmt.Sprintf("%d - the time is %v", n, time.Now())
				fmt.Fprintf(w, "data: Message: %s\n\n", msg)
				if flushErr := w.Flush(); flushErr != nil {
					break
				}
			}
		})
		return nil
	}

	scenarios := []struct {
		name  string
		path  string
		setup func(app *fiber.App)
	}{
		{name: "NoMiddleware", path: "/", setup: func(app *fiber.App) {
			app.Get("/", greet)
		}},
		{name: "WithBytesAndStatus", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{
				Format: "${bytesReceived} ${bytesSent} ${status}",
				Stream: io.Discard,
			}))
			app.Get("/", greetWithHeader)
		}},
		{name: "DefaultFormat", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Stream: io.Discard}))
			app.Get("/", greet)
		}},
		{name: "DefaultFormatWithFiberLog", path: "/", setup: func(app *fiber.App) {
			logger := fiberlog.DefaultLogger[*log.Logger]()
			logger.SetOutput(io.Discard)
			app.Use(New(Config{Stream: LoggerToWriter(logger, fiberlog.LevelDebug)}))
			app.Get("/", greet)
		}},
		{name: "DefaultFormatDisableColors", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Stream: io.Discard, DisableColors: true}))
			app.Get("/", greet)
		}},
		{name: "DefaultFormatForceColors", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Stream: io.Discard, ForceColors: true}))
			app.Get("/", greet)
		}},
		{name: "WithTagParameter", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{
				Format: "${bytesReceived} ${bytesSent} ${status} ${reqHeader:test}",
				Stream: io.Discard,
			}))
			app.Get("/", greetWithHeader)
		}},
		{name: "WithLocals", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Format: "${locals:demo}", Stream: io.Discard}))
			app.Get("/", func(c fiber.Ctx) error {
				c.Locals("demo", "johndoe")
				return c.SendStatus(fiber.StatusOK)
			})
		}},
		{name: "WithLocalsInt", path: "/int", setup: func(app *fiber.App) {
			app.Use(New(Config{Format: "${locals:demo}", Stream: io.Discard}))
			app.Get("/int", func(c fiber.Ctx) error {
				c.Locals("demo", 55)
				return c.SendStatus(fiber.StatusOK)
			})
		}},
		{name: "WithCustomDone", path: "/logging", setup: func(app *fiber.App) {
			app.Use(New(Config{
				Done: func(c fiber.Ctx, logString []byte) {
					if c.Response().StatusCode() == fiber.StatusOK {
						io.Discard.Write(logString) //nolint:errcheck // ignore error
					}
				},
				Stream: io.Discard,
			}))
			app.Get("/logging", func(c fiber.Ctx) error {
				return c.SendStatus(fiber.StatusOK)
			})
		}},
		{name: "WithAllTags", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Format: allTagsFormat, Stream: io.Discard}))
			app.Get("/", greet)
		}},
		{name: "Streaming", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{
				Format: "${bytesReceived} ${bytesSent} ${status}",
				Stream: io.Discard,
			}))
			app.Get("/", streamMessages)
		}},
		{name: "WithBody", path: "/", setup: func(app *fiber.App) {
			app.Use(New(Config{Format: "${resBody}", Stream: io.Discard}))
			app.Get("/", func(c fiber.Ctx) error {
				return c.SendString("Sample response body")
			})
		}},
	}

	for _, sc := range scenarios {
		b.Run(sc.name, func(bb *testing.B) {
			// A fresh app per run keeps sub-benchmarks independent.
			app := fiber.New()
			sc.setup(app)
			benchmarkSetupParallel(bb, app, sc.path)
		})
	}
}
================================================
FILE: middleware/logger/tags.go
================================================
package logger
import (
	"fmt"
	"maps"
	"slices"
	"strings"

	"github.com/gofiber/fiber/v3"
)
// Logger tag names usable in Config.Format, written as "${name}" in the format
// string.

// Tags that render a value taken from the request, the response, or the
// middleware's per-request Data.
const (
	TagPid               = "pid"
	TagTime              = "time"
	TagReferer           = "referer"
	TagProtocol          = "protocol"
	TagScheme            = "scheme"
	TagPort              = "port"
	TagIP                = "ip"
	TagIPs               = "ips"
	TagHost              = "host"
	TagMethod            = "method"
	TagPath              = "path"
	TagURL               = "url"
	TagUA                = "ua"
	TagLatency           = "latency"
	TagStatus            = "status"
	TagResBody           = "resBody"
	TagReqHeaders        = "reqHeaders"
	TagQueryStringParams = "queryParams"
	TagBody              = "body"
	TagBytesSent         = "bytesSent"
	TagBytesReceived     = "bytesReceived"
	TagRoute             = "route"
	TagError             = "error"
)

// Parameterized tags: the name ends with ":" and the text after the colon is
// passed to the tag as its parameter, e.g. "${reqHeader:X-Request-ID}".
const (
	TagReqHeader  = "reqHeader:"
	TagRespHeader = "respHeader:"
	TagLocals     = "locals:"
	TagQuery      = "query:"
	TagForm       = "form:"
	TagCookie     = "cookie:"
)

// Color tags: each writes the matching escape sequence from the app's
// ColorScheme.
const (
	TagBlack   = "black"
	TagRed     = "red"
	TagGreen   = "green"
	TagYellow  = "yellow"
	TagBlue    = "blue"
	TagMagenta = "magenta"
	TagCyan    = "cyan"
	TagWhite   = "white"
	TagReset   = "reset"
)
// createTagMap builds the lookup table from tag name to LogFunc used to render a
// log line. It starts from the built-in tags and then copies cfg.CustomTags on
// top, so a custom tag with the same name as a built-in one replaces it.
//
// The color-aware tags (error, status, method) read cfg.enableColors each time
// they render, through the captured cfg pointer.
func createTagMap(cfg *Config) map[string]LogFunc {
	// Set default tags
	tagFunctions := map[string]LogFunc{
		TagReferer: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Get(fiber.HeaderReferer))
		},
		TagProtocol: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Protocol())
		},
		TagScheme: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Scheme())
		},
		TagPort: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Port())
		},
		TagIP: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.IP())
		},
		TagIPs: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Get(fiber.HeaderXForwardedFor))
		},
		TagHost: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Hostname())
		},
		TagPath: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Path())
		},
		TagURL: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.OriginalURL())
		},
		TagUA: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Get(fiber.HeaderUserAgent))
		},
		TagBody: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.Write(c.Body())
		},
		TagBytesReceived: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return appendInt(output, c.Request().Header.ContentLength())
		},
		TagBytesSent: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return appendInt(output, c.Response().Header.ContentLength())
		},
		TagRoute: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Route().Path)
		},
		TagResBody: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.Write(c.Response().Body())
		},
		TagReqHeaders: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			out := make(map[string][]string)
			if err := c.Bind().Header(&out); err != nil {
				return 0, err
			}
			// Go randomizes map iteration order; visit header names in sorted
			// order so the same request always produces the same log line.
			reqHeaders := make([]string, 0, len(out))
			for _, k := range slices.Sorted(maps.Keys(out)) {
				reqHeaders = append(reqHeaders, k+"="+strings.Join(out[k], ","))
			}
			return output.WriteString(strings.Join(reqHeaders, "&"))
		},
		TagQueryStringParams: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.Request().URI().QueryArgs().String())
		},
		TagBlack: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.App().Config().ColorScheme.Black)
		},
		TagRed: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.App().Config().ColorScheme.Red)
		},
		TagGreen: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.App().Config().ColorScheme.Green)
		},
		TagYellow: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.App().Config().ColorScheme.Yellow)
		},
		TagBlue: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.App().Config().ColorScheme.Blue)
		},
		TagMagenta: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.App().Config().ColorScheme.Magenta)
		},
		TagCyan: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.App().Config().ColorScheme.Cyan)
		},
		TagWhite: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.App().Config().ColorScheme.White)
		},
		TagReset: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			return output.WriteString(c.App().Config().ColorScheme.Reset)
		},
		TagError: func(output Buffer, c fiber.Ctx, data *Data, _ string) (int, error) {
			if data.ChainErr != nil {
				if cfg.enableColors {
					colors := c.App().Config().ColorScheme
					return fmt.Fprintf(output, "%s%s%s", colors.Red, data.ChainErr.Error(), colors.Reset)
				}
				return output.WriteString(data.ChainErr.Error())
			}
			// No error in the handler chain: write a placeholder.
			return output.WriteString("-")
		},
		TagReqHeader: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) {
			return output.WriteString(c.Get(extraParam))
		},
		TagRespHeader: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) {
			return output.WriteString(c.GetRespHeader(extraParam))
		},
		TagQuery: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) {
			return output.WriteString(fiber.Query[string](c, extraParam))
		},
		TagForm: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) {
			return output.WriteString(c.FormValue(extraParam))
		},
		TagCookie: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) {
			return output.WriteString(c.Cookies(extraParam))
		},
		TagLocals: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) {
			// Byte slices and strings are written as-is, a missing local writes
			// nothing, and any other value is rendered with %v.
			switch v := c.Locals(extraParam).(type) {
			case []byte:
				return output.Write(v)
			case string:
				return output.WriteString(v)
			case nil:
				return 0, nil
			default:
				return fmt.Fprintf(output, "%v", v)
			}
		},
		TagStatus: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			status := c.Response().StatusCode()
			if cfg.enableColors {
				colors := c.App().Config().ColorScheme
				return fmt.Fprintf(output, "%s%3d%s", statusColor(status, &colors), status, colors.Reset)
			}
			return appendInt(output, status)
		},
		TagMethod: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) {
			if cfg.enableColors {
				colors := c.App().Config().ColorScheme
				return fmt.Fprintf(output, "%s%s%s", methodColor(c.Method(), &colors), c.Method(), colors.Reset)
			}
			return output.WriteString(c.Method())
		},
		TagPid: func(output Buffer, _ fiber.Ctx, data *Data, _ string) (int, error) {
			return output.WriteString(data.Pid)
		},
		TagLatency: func(output Buffer, _ fiber.Ctx, data *Data, _ string) (int, error) {
			latency := data.Stop.Sub(data.Start)
			return fmt.Fprintf(output, "%13v", latency)
		},
		TagTime: func(output Buffer, _ fiber.Ctx, data *Data, _ string) (int, error) {
			return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert,errcheck // We always store a string in here
		},
	}
	// merge with custom tags from user
	maps.Copy(tagFunctions, cfg.CustomTags)

	return tagFunctions
}
================================================
FILE: middleware/logger/template_chain.go
================================================
package logger
import (
"bytes"
"fmt"
"github.com/gofiber/utils/v2"
)
// buildLogFuncChain analyzes the template and creates slices with the functions for execution and
// slices with the fixed parts of the template and the parameters
//
// fixParts contains the fixed parts of the template or parameters if a function is stored in the funcChain at this position
// funcChain contains for the parts which exist the functions for the dynamic parts
// funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain,
// if a function exists for the part, a parameter for it can also exist in the fixParts slice
func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) {
// process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62
templateB := utils.UnsafeBytes(cfg.Format)
startTagB := utils.UnsafeBytes(startTag)
endTagB := utils.UnsafeBytes(endTag)
paramSeparatorB := utils.UnsafeBytes(paramSeparator)
var fixParts [][]byte
var funcChain []LogFunc
for {
before, after, found := bytes.Cut(templateB, startTagB)
if !found {
// no starting tag found in the existing template part
break
}
// add fixed part
funcChain = append(funcChain, nil)
fixParts = append(fixParts, before)
templateB = after
before, after, found = bytes.Cut(templateB, endTagB)
if !found {
// cannot find end tag - just write it to the output.
funcChain = append(funcChain, nil)
fixParts = append(fixParts, startTagB)
break
}
// ## function block ##
// first check for tags with parameters
tag, param, foundParam := bytes.Cut(before, paramSeparatorB)
if foundParam {
logFunc, ok := tagFunctions[utils.UnsafeString(tag)+paramSeparator]
if !ok {
return nil, nil, fmt.Errorf("%w: %q", ErrTemplateParameterMissing, utils.UnsafeString(before))
}
funcChain = append(funcChain, logFunc)
// add param to the fixParts
fixParts = append(fixParts, param)
} else if logFunc, ok := tagFunctions[utils.UnsafeString(before)]; ok {
// add functions without parameter
funcChain = append(funcChain, logFunc)
fixParts = append(fixParts, nil)
}
// ## function block end ##
// reduce the template string
templateB = after
}
// set the rest
funcChain = append(funcChain, nil)
fixParts = append(fixParts, templateB)
return fixParts, funcChain, nil
}
================================================
FILE: middleware/logger/utils.go
================================================
package logger
import (
"io"
"github.com/gofiber/fiber/v3"
fiberlog "github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
)
func methodColor(method string, colors *fiber.Colors) string {
if colors == nil {
return ""
}
switch method {
case fiber.MethodGet:
return colors.Cyan
case fiber.MethodPost:
return colors.Green
case fiber.MethodPut:
return colors.Yellow
case fiber.MethodDelete:
return colors.Red
case fiber.MethodPatch:
return colors.White
case fiber.MethodHead:
return colors.Magenta
case fiber.MethodOptions:
return colors.Blue
default:
return colors.Reset
}
}
func statusColor(code int, colors *fiber.Colors) string {
if colors == nil {
return ""
}
switch {
case code >= fiber.StatusOK && code < fiber.StatusMultipleChoices:
return colors.Green
case code >= fiber.StatusMultipleChoices && code < fiber.StatusBadRequest:
return colors.Blue
case code >= fiber.StatusBadRequest && code < fiber.StatusInternalServerError:
return colors.Yellow
default:
return colors.Red
}
}
type customLoggerWriter[T any] struct {
loggerInstance fiberlog.AllLogger[T]
level fiberlog.Level
}
// Write implements io.Writer and forwards the payload to the configured logger.
func (cl *customLoggerWriter[T]) Write(p []byte) (int, error) {
switch cl.level {
case fiberlog.LevelTrace:
cl.loggerInstance.Trace(utils.UnsafeString(p))
case fiberlog.LevelDebug:
cl.loggerInstance.Debug(utils.UnsafeString(p))
case fiberlog.LevelInfo:
cl.loggerInstance.Info(utils.UnsafeString(p))
case fiberlog.LevelWarn:
cl.loggerInstance.Warn(utils.UnsafeString(p))
case fiberlog.LevelError:
cl.loggerInstance.Error(utils.UnsafeString(p))
default:
return 0, nil
}
return len(p), nil
}
// LoggerToWriter is a helper function that returns an io.Writer that writes to a custom logger.
// You can integrate 3rd party loggers such as zerolog, logrus, etc. to logger middleware using this function.
//
// Valid levels: fiberlog.LevelInfo, fiberlog.LevelTrace, fiberlog.LevelWarn, fiberlog.LevelDebug, fiberlog.LevelError
func LoggerToWriter[T any](logger fiberlog.AllLogger[T], level fiberlog.Level) io.Writer {
// Check if customLogger is nil
if logger == nil {
fiberlog.Panic("LoggerToWriter: customLogger must not be nil")
}
// Check if level is valid
if level == fiberlog.LevelFatal || level == fiberlog.LevelPanic {
fiberlog.Panic("LoggerToWriter: invalid level")
}
return &customLoggerWriter[T]{
level: level,
loggerInstance: logger,
}
}
================================================
FILE: middleware/paginate/config.go
================================================
package paginate
import (
"slices"
"github.com/gofiber/fiber/v3"
)
// Config defines the config for the pagination middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
Next func(c fiber.Ctx) bool
// PageKey is the query string key for page number.
//
// Optional. Default: "page"
PageKey string
// LimitKey is the query string key for limit.
//
// Optional. Default: "limit"
LimitKey string
// SortKey is the query string key for sort.
//
// Optional. Default: ""
SortKey string
// DefaultSort is the default sort field.
//
// Optional. Default: "id"
DefaultSort string
// CursorKey is the query string key for cursor-based pagination.
//
// Optional. Default: "cursor"
CursorKey string
// OffsetKey is the query string key for offset.
//
// Optional. Default: "offset"
OffsetKey string
// CursorParam is an optional alias for the cursor query key.
//
// Optional. Default: ""
CursorParam string
// AllowedSorts is the list of allowed sort fields.
//
// Optional. Default: nil
AllowedSorts []string
// DefaultPage is the default page number.
//
// Optional. Default: 1
DefaultPage int
// DefaultLimit is the default items per page.
//
// Optional. Default: 10
DefaultLimit int
// MaxLimit is the maximum items per page.
//
// Optional. Default: 100
MaxLimit int
}
// ConfigDefault is the default config.
var ConfigDefault = Config{
Next: nil,
PageKey: "page",
DefaultPage: 1,
LimitKey: "limit",
DefaultLimit: 10,
MaxLimit: DefaultMaxLimit,
DefaultSort: "id",
OffsetKey: "offset",
CursorKey: "cursor",
}
func configDefault(config ...Config) Config {
if len(config) < 1 {
return ConfigDefault
}
cfg := config[0]
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if cfg.PageKey == "" {
cfg.PageKey = ConfigDefault.PageKey
}
if cfg.DefaultLimit < 1 {
cfg.DefaultLimit = ConfigDefault.DefaultLimit
}
if cfg.LimitKey == "" {
cfg.LimitKey = ConfigDefault.LimitKey
}
if cfg.DefaultPage < 1 {
cfg.DefaultPage = ConfigDefault.DefaultPage
}
if cfg.CursorKey == "" {
cfg.CursorKey = ConfigDefault.CursorKey
}
if cfg.DefaultSort == "" {
cfg.DefaultSort = ConfigDefault.DefaultSort
}
if cfg.OffsetKey == "" {
cfg.OffsetKey = ConfigDefault.OffsetKey
}
if cfg.MaxLimit < 1 {
cfg.MaxLimit = ConfigDefault.MaxLimit
}
if cfg.DefaultLimit > cfg.MaxLimit {
cfg.DefaultLimit = cfg.MaxLimit
}
if len(cfg.AllowedSorts) > 0 && !slices.Contains(cfg.AllowedSorts, cfg.DefaultSort) {
cfg.DefaultSort = cfg.AllowedSorts[0]
}
return cfg
}
================================================
FILE: middleware/paginate/page_info.go
================================================
package paginate
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/url"
"github.com/gofiber/utils/v2"
)
// ErrCursorEncode is returned when cursor values cannot be encoded.
var ErrCursorEncode = errors.New("paginate: failed to encode cursor values")
// SortOrder represents sort order.
type SortOrder string
const (
ASC SortOrder = "asc"
DESC SortOrder = "desc"
)
// SortField represents a sort field with direction.
type SortField struct {
Field string `json:"field"`
Order SortOrder `json:"order"`
}
// SortOrderFromString returns a SortOrder from a string (case-insensitive).
func SortOrderFromString(s string) SortOrder {
if utils.EqualFold(s, "desc") {
return DESC
}
return ASC
}
// PageInfo contains pagination information.
type PageInfo struct {
cursorData map[string]any
jsonMarshal utils.JSONMarshal
jsonUnmarshal utils.JSONUnmarshal
Cursor string `json:"cursor,omitempty"`
NextCursor string `json:"next_cursor,omitempty"`
Sort []SortField `json:"sort"`
Page int `json:"page"`
Limit int `json:"limit"`
Offset int `json:"offset"`
HasMore bool `json:"has_more,omitempty"`
}
// NewPageInfo creates a new PageInfo.
func NewPageInfo(page, limit, offset int, sort []SortField) *PageInfo {
return &PageInfo{
Page: page,
Limit: limit,
Offset: offset,
Sort: sort,
}
}
// Start returns the start index based on page/limit or offset.
func (p *PageInfo) Start() int {
if p.Offset > 0 {
return p.Offset
}
if p.Page < 1 {
return 0
}
return (p.Page - 1) * p.Limit
}
// SortBy adds a sort field. Chainable.
func (p *PageInfo) SortBy(field string, order SortOrder) *PageInfo {
p.Sort = append(p.Sort, SortField{Field: field, Order: order})
return p
}
// NextPageURLWithKeys returns the URL for the next page using custom query keys.
func (p *PageInfo) NextPageURLWithKeys(baseURL, pageKey, limitKey string) string {
return buildPaginationURL(baseURL, pageKey, utils.FormatInt(int64(p.Page+1)), limitKey, utils.FormatInt(int64(p.Limit)))
}
// NextPageURL returns the URL for the next page.
func (p *PageInfo) NextPageURL(baseURL string) string {
return p.NextPageURLWithKeys(baseURL, "page", "limit")
}
// PreviousPageURLWithKeys returns the URL for the previous page using custom query keys.
// Returns empty string if on page 1.
func (p *PageInfo) PreviousPageURLWithKeys(baseURL, pageKey, limitKey string) string {
if p.Page > 1 {
return buildPaginationURL(baseURL, pageKey, utils.FormatInt(int64(p.Page-1)), limitKey, utils.FormatInt(int64(p.Limit)))
}
return ""
}
// PreviousPageURL returns the URL for the previous page.
// Returns empty string if on page 1.
func (p *PageInfo) PreviousPageURL(baseURL string) string {
return p.PreviousPageURLWithKeys(baseURL, "page", "limit")
}
// NextCursorURLWithKeys returns the URL for the next cursor page using custom query keys.
// Returns empty string if HasMore is false.
func (p *PageInfo) NextCursorURLWithKeys(baseURL, cursorKey, limitKey string) string {
if !p.HasMore {
return ""
}
return buildPaginationURL(baseURL, cursorKey, p.NextCursor, limitKey, utils.FormatInt(int64(p.Limit)))
}
// NextCursorURL returns the URL for the next cursor page.
// Returns empty string if HasMore is false.
func (p *PageInfo) NextCursorURL(baseURL string) string {
return p.NextCursorURLWithKeys(baseURL, "cursor", "limit")
}
// buildPaginationURL parses baseURL and sets/replaces two query parameters,
// preserving any existing query string values.
func buildPaginationURL(baseURL, pageParam, pageValue, limitParam, limitValue string) string {
u, err := url.Parse(baseURL)
if err != nil {
return baseURL
}
q := u.Query()
q.Set(pageParam, pageValue)
q.Set(limitParam, limitValue)
u.RawQuery = q.Encode()
return u.String()
}
// CursorValues returns the decoded cursor key-value map.
// If the cursor was parsed by the middleware, the pre-parsed data is returned.
// Otherwise it decodes the opaque cursor string.
// Returns nil if cursor is empty or invalid.
func (p *PageInfo) CursorValues() map[string]any {
if p.cursorData != nil {
return p.cursorData
}
if p.Cursor == "" {
return nil
}
if len(p.Cursor) > maxCursorLen {
return nil
}
data, err := base64.RawURLEncoding.DecodeString(p.Cursor)
if err != nil {
return nil
}
var values map[string]any
unmarshal := p.jsonUnmarshal
if unmarshal == nil {
unmarshal = json.Unmarshal
}
if err := unmarshal(data, &values); err != nil {
return nil
}
p.cursorData = values
return values
}
// SetNextCursor encodes a key-value map into an opaque cursor token
// and sets both NextCursor and HasMore on the PageInfo.
func (p *PageInfo) SetNextCursor(values map[string]any) error {
marshal := p.jsonMarshal
if marshal == nil {
marshal = json.Marshal
}
data, err := marshal(values)
if err != nil {
return fmt.Errorf("%w: %w", ErrCursorEncode, err)
}
encoded := base64.RawURLEncoding.EncodeToString(data)
if len(encoded) > maxCursorLen {
return fmt.Errorf("%w: cursor token exceeds maximum length (%d)", ErrCursorEncode, maxCursorLen)
}
p.NextCursor = encoded
p.HasMore = true
return nil
}
================================================
FILE: middleware/paginate/paginate.go
================================================
package paginate
import (
"encoding/base64"
"slices"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
)
// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int
const (
pageInfoKey contextKey = iota
)
// DefaultMaxLimit is the default maximum limit allowed.
const DefaultMaxLimit = 100
// maxCursorLen is the maximum allowed cursor string length.
const maxCursorLen = 2048
// New creates a new pagination middleware handler.
func New(config ...Config) fiber.Handler {
cfg := configDefault(config...)
return func(c fiber.Ctx) error {
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
appCfg := c.App().Config()
limit := fiber.Query(c, cfg.LimitKey, cfg.DefaultLimit)
if limit < 1 {
limit = cfg.DefaultLimit
}
if limit > cfg.MaxLimit {
limit = cfg.MaxLimit
}
sorts := parseSortQuery(c.Query(cfg.SortKey), cfg.AllowedSorts, cfg.DefaultSort)
cursorRaw := c.Query(cfg.CursorKey)
if cursorRaw == "" && cfg.CursorParam != "" {
cursorRaw = c.Query(cfg.CursorParam)
}
if cursorRaw != "" {
if len(cursorRaw) > maxCursorLen {
return fiber.NewError(fiber.StatusBadRequest, "cursor too long")
}
data, err := base64.RawURLEncoding.DecodeString(cursorRaw)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, "invalid cursor")
}
var obj map[string]any
if err := appCfg.JSONDecoder(data, &obj); err != nil {
return fiber.NewError(fiber.StatusBadRequest, "invalid cursor")
}
pageInfo := &PageInfo{
Limit: limit,
Sort: sorts,
Cursor: cursorRaw,
cursorData: obj,
jsonMarshal: appCfg.JSONEncoder,
jsonUnmarshal: appCfg.JSONDecoder,
}
fiber.StoreInContext(c, pageInfoKey, pageInfo)
return c.Next()
}
page := max(fiber.Query(c, cfg.PageKey, cfg.DefaultPage), 1)
offset := max(fiber.Query(c, cfg.OffsetKey, 0), 0)
pageInfo := NewPageInfo(page, limit, offset, sorts)
pageInfo.jsonMarshal = appCfg.JSONEncoder
pageInfo.jsonUnmarshal = appCfg.JSONDecoder
fiber.StoreInContext(c, pageInfoKey, pageInfo)
return c.Next()
}
}
// FromContext returns the PageInfo from the request context.
// It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context.
// Returns nil and false if no PageInfo is stored.
func FromContext(ctx any) (*PageInfo, bool) {
return fiber.ValueFromContext[*PageInfo](ctx, pageInfoKey)
}
func parseSortQuery(query string, allowedSorts []string, defaultSort string) []SortField {
if query == "" {
return []SortField{{Field: defaultSort, Order: ASC}}
}
fields := strings.Split(query, ",")
sortFields := make([]SortField, 0, len(fields))
for _, field := range fields {
field = utils.TrimSpace(field)
if field == "" {
continue
}
order := ASC
if strings.HasPrefix(field, "-") {
order = DESC
field = utils.TrimSpace(field[1:])
}
if field == "" {
continue
}
if len(allowedSorts) == 0 || slices.Contains(allowedSorts, field) {
sortFields = append(sortFields, SortField{Field: field, Order: order})
}
}
if len(sortFields) == 0 {
return []SortField{{Field: defaultSort, Order: ASC}}
}
return sortFields
}
================================================
FILE: middleware/paginate/paginate_test.go
================================================
package paginate
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
type paginateResponse struct {
NextPageURL string `json:"next_page_url"`
PreviousPageURL string `json:"prev_page_url"`
Sort []SortField `json:"sort"`
Page int `json:"page"`
Limit int `json:"limit"`
Offset int `json:"offset"`
Start int `json:"start"`
}
type cursorResponse struct {
Cursor string `json:"cursor"`
NextCursor string `json:"next_cursor"`
Sort []SortField `json:"sort"`
Limit int `json:"limit"`
HasMore bool `json:"has_more"`
}
// --- Config tests ---
func Test_ConfigDefault(t *testing.T) {
t.Parallel()
cfg := configDefault()
require.Equal(t, "page", cfg.PageKey)
require.Equal(t, 1, cfg.DefaultPage)
require.Equal(t, "limit", cfg.LimitKey)
require.Equal(t, 10, cfg.DefaultLimit)
require.Equal(t, DefaultMaxLimit, cfg.MaxLimit)
}
func Test_ConfigOverride(t *testing.T) {
t.Parallel()
cfg := configDefault(Config{
PageKey: "p",
LimitKey: "l",
DefaultPage: 5,
DefaultLimit: 50,
})
require.Equal(t, "p", cfg.PageKey)
require.Equal(t, "l", cfg.LimitKey)
require.Equal(t, 5, cfg.DefaultPage)
require.Equal(t, 50, cfg.DefaultLimit)
}
func Test_ConfigDefaultCursorKey(t *testing.T) {
t.Parallel()
cfg := configDefault()
require.Equal(t, "cursor", cfg.CursorKey)
}
func Test_ConfigOverrideCursorKey(t *testing.T) {
t.Parallel()
cfg := configDefault(Config{
CursorKey: "after",
CursorParam: "starting_after",
})
require.Equal(t, "after", cfg.CursorKey)
require.Equal(t, "starting_after", cfg.CursorParam)
}
func Test_ConfigNegativeDefaults(t *testing.T) {
t.Parallel()
cfg := configDefault(Config{
DefaultPage: -1,
DefaultLimit: -1,
})
require.Equal(t, 1, cfg.DefaultPage)
require.Equal(t, 10, cfg.DefaultLimit)
}
// --- PageInfo tests ---
func Test_SortOrderFromString(t *testing.T) {
t.Parallel()
tests := []struct {
input string
expected SortOrder
}{
{"asc", ASC},
{"desc", DESC},
{"DESC", DESC},
{"Desc", DESC},
{"invalid", ASC},
{"", ASC},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.expected, SortOrderFromString(tt.input))
})
}
}
func Test_PageInfoStart(t *testing.T) {
t.Parallel()
tests := []struct {
name string
pageInfo PageInfo
expected int
}{
{"Page 1, limit 10", PageInfo{Page: 1, Limit: 10}, 0},
{"Page 2, limit 10", PageInfo{Page: 2, Limit: 10}, 10},
{"Page 3, limit 20", PageInfo{Page: 3, Limit: 20}, 40},
{"With offset", PageInfo{Page: 2, Limit: 10, Offset: 25}, 25},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.expected, tt.pageInfo.Start())
})
}
}
func Test_PageInfoSortBy(t *testing.T) {
t.Parallel()
p := NewPageInfo(1, 10, 0, nil)
p.SortBy("name", ASC).SortBy("date", DESC)
require.Len(t, p.Sort, 2)
require.Equal(t, "name", p.Sort[0].Field)
require.Equal(t, ASC, p.Sort[0].Order)
require.Equal(t, "date", p.Sort[1].Field)
require.Equal(t, DESC, p.Sort[1].Order)
}
func Test_PageInfoNextPageURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
baseURL string
expected string
pageInfo PageInfo
}{
{
name: "Middle page",
baseURL: "https://example.com/users",
expected: "https://example.com/users?limit=10&page=3",
pageInfo: PageInfo{Page: 2, Limit: 10},
},
{
name: "First page",
baseURL: "https://example.com/users",
expected: "https://example.com/users?limit=20&page=2",
pageInfo: PageInfo{Page: 1, Limit: 20},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.expected, tt.pageInfo.NextPageURL(tt.baseURL))
})
}
}
func Test_PageInfoPreviousPageURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
baseURL string
expected string
pageInfo PageInfo
}{
{
name: "Middle page",
baseURL: "https://example.com/users",
expected: "https://example.com/users?limit=10&page=1",
pageInfo: PageInfo{Page: 2, Limit: 10},
},
{
name: "First page returns empty",
baseURL: "https://example.com/users",
expected: "",
pageInfo: PageInfo{Page: 1, Limit: 20},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.expected, tt.pageInfo.PreviousPageURL(tt.baseURL))
})
}
}
func Test_PageInfoStartCursorMode(t *testing.T) {
t.Parallel()
// In cursor mode, Page is 0 (not set). Start() should return 0, not negative.
p := &PageInfo{Page: 0, Limit: 20}
require.Equal(t, 0, p.Start())
}
func Test_PageInfoNextPageURLWithExistingQueryParams(t *testing.T) {
t.Parallel()
p := PageInfo{Page: 2, Limit: 10}
result := p.NextPageURL("https://example.com/users?filter=active")
require.Contains(t, result, "filter=active")
require.Contains(t, result, "page=3")
require.Contains(t, result, "limit=10")
}
func Test_PageInfoPreviousPageURLWithExistingQueryParams(t *testing.T) {
t.Parallel()
p := PageInfo{Page: 3, Limit: 10}
result := p.PreviousPageURL("https://example.com/users?filter=active")
require.Contains(t, result, "filter=active")
require.Contains(t, result, "page=2")
require.Contains(t, result, "limit=10")
}
func Test_PageInfoCursorFields(t *testing.T) {
t.Parallel()
p := &PageInfo{
Cursor: "abc123",
HasMore: true,
NextCursor: "def456",
}
require.Equal(t, "abc123", p.Cursor)
require.True(t, p.HasMore)
require.Equal(t, "def456", p.NextCursor)
}
func Test_CursorValuesRoundTrip(t *testing.T) {
t.Parallel()
original := map[string]any{
"id": float64(42),
"created_at": "2026-01-01T00:00:00Z",
}
p := &PageInfo{}
require.NoError(t, p.SetNextCursor(original))
require.True(t, p.HasMore)
require.NotEmpty(t, p.NextCursor)
p2 := &PageInfo{Cursor: p.NextCursor}
decoded := p2.CursorValues()
require.NotNil(t, decoded)
require.InEpsilon(t, float64(42), decoded["id"], 0)
require.Equal(t, "2026-01-01T00:00:00Z", decoded["created_at"])
}
func Test_CursorValuesEmptyCursor(t *testing.T) {
t.Parallel()
p := &PageInfo{Cursor: ""}
require.Nil(t, p.CursorValues())
}
func Test_CursorValuesInvalidBase64(t *testing.T) {
t.Parallel()
p := &PageInfo{Cursor: "not-valid-base64!!!"}
require.Nil(t, p.CursorValues())
}
func Test_CursorValuesInvalidJSON(t *testing.T) {
t.Parallel()
p := &PageInfo{Cursor: "bm90LWpzb24"}
require.Nil(t, p.CursorValues())
}
func Test_NextCursorURL(t *testing.T) {
t.Parallel()
t.Run("with HasMore", func(t *testing.T) {
t.Parallel()
p := &PageInfo{Limit: 20}
require.NoError(t, p.SetNextCursor(map[string]any{"id": float64(42)}))
url := p.NextCursorURL("https://example.com/users")
expected := fmt.Sprintf("https://example.com/users?cursor=%s&limit=20", p.NextCursor)
require.Equal(t, expected, url)
})
t.Run("without HasMore", func(t *testing.T) {
t.Parallel()
p := &PageInfo{Limit: 20}
require.Empty(t, p.NextCursorURL("https://example.com/users"))
})
}
func Test_SetNextCursorSetsFields(t *testing.T) {
t.Parallel()
p := &PageInfo{Limit: 10}
require.NoError(t, p.SetNextCursor(map[string]any{"id": float64(1)}))
require.True(t, p.HasMore)
require.NotEmpty(t, p.NextCursor)
}
// --- Middleware handler tests ---
func Test_PaginateWithQueries(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
DefaultSort: "id",
}))
app.Get("/", func(c fiber.Ctx) error {
pageInfo, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(paginateResponse{
Page: pageInfo.Page,
Limit: pageInfo.Limit,
Offset: pageInfo.Offset,
Start: pageInfo.Start(),
Sort: pageInfo.Sort,
NextPageURL: pageInfo.NextPageURL(c.BaseURL()),
PreviousPageURL: pageInfo.PreviousPageURL(c.BaseURL()),
})
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?page=2&limit=20", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
require.Equal(t, fiber.StatusOK, resp.StatusCode)
var body paginateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 2, body.Page)
require.Equal(t, 20, body.Limit)
require.Equal(t, 0, body.Offset)
require.Equal(t, 20, body.Start)
require.Equal(t, "http://example.com?limit=20&page=3", body.NextPageURL)
require.Equal(t, "http://example.com?limit=20&page=1", body.PreviousPageURL)
require.Equal(t, []SortField{{Field: "id", Order: ASC}}, body.Sort)
}
func Test_PaginateWithOffset(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c fiber.Ctx) error {
pageInfo, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(paginateResponse{
Page: pageInfo.Page,
Limit: pageInfo.Limit,
Offset: pageInfo.Offset,
Start: pageInfo.Start(),
Sort: pageInfo.Sort,
})
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?offset=20&limit=20", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
require.Equal(t, fiber.StatusOK, resp.StatusCode)
var body paginateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 1, body.Page)
require.Equal(t, 20, body.Limit)
require.Equal(t, 20, body.Offset)
require.Equal(t, 20, body.Start)
}
func Test_PaginateCheckDefaultsWhenNoQueries(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c fiber.Ctx) error {
pageInfo, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(paginateResponse{
Page: pageInfo.Page,
Limit: pageInfo.Limit,
Offset: pageInfo.Offset,
Start: pageInfo.Start(),
Sort: pageInfo.Sort,
})
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
require.Equal(t, fiber.StatusOK, resp.StatusCode)
var body paginateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 1, body.Page)
require.Equal(t, 10, body.Limit)
require.Equal(t, 0, body.Offset)
require.Equal(t, 0, body.Start)
require.Equal(t, []SortField{{Field: "id", Order: ASC}}, body.Sort)
}
func Test_PaginateCheckDefaultsWhenNoPage(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c fiber.Ctx) error {
pageInfo, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(paginateResponse{
Page: pageInfo.Page,
Limit: pageInfo.Limit,
Start: pageInfo.Start(),
})
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?limit=20", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
var body paginateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 1, body.Page)
require.Equal(t, 20, body.Limit)
require.Equal(t, 0, body.Start)
}
func Test_PaginateCheckDefaultsWhenNoLimit(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/", func(c fiber.Ctx) error {
pageInfo, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(paginateResponse{
Page: pageInfo.Page,
Limit: pageInfo.Limit,
Start: pageInfo.Start(),
})
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?page=2", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
var body paginateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 2, body.Page)
require.Equal(t, 10, body.Limit)
require.Equal(t, 10, body.Start)
}
func Test_PaginateConfigDefaultPageDefaultLimit(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
DefaultPage: 100,
DefaultLimit: DefaultMaxLimit,
DefaultSort: "name",
}))
app.Get("/", func(c fiber.Ctx) error {
pageInfo, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(paginateResponse{
Page: pageInfo.Page,
Limit: pageInfo.Limit,
Start: pageInfo.Start(),
Sort: pageInfo.Sort,
})
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
var body paginateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 100, body.Page)
require.Equal(t, DefaultMaxLimit, body.Limit)
require.Equal(t, 9900, body.Start)
require.Equal(t, []SortField{{Field: "name", Order: ASC}}, body.Sort)
}
func Test_PaginateConfigPageKeyLimitKey(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
PageKey: "site",
LimitKey: "size",
DefaultSort: "id",
}))
app.Get("/", func(c fiber.Ctx) error {
pageInfo, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(paginateResponse{
Page: pageInfo.Page,
Limit: pageInfo.Limit,
Start: pageInfo.Start(),
Sort: pageInfo.Sort,
})
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?site=2&size=5", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
var body paginateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 2, body.Page)
require.Equal(t, 5, body.Limit)
require.Equal(t, 5, body.Start)
}
func Test_PaginateNegativeDefaultPageDefaultLimitValues(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
DefaultPage: -1,
DefaultLimit: -1,
DefaultSort: "id",
}))
app.Get("/", func(c fiber.Ctx) error {
pageInfo, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(paginateResponse{
Page: pageInfo.Page,
Limit: pageInfo.Limit,
Start: pageInfo.Start(),
})
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
var body paginateResponse
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
require.Equal(t, 1, body.Page)
require.Equal(t, 10, body.Limit)
require.Equal(t, 0, body.Start)
}
func Test_PaginateFromContextWithoutNew(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/", func(c fiber.Ctx) error {
_, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(nil)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
}
func Test_PaginateNextSkip(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
Next: func(_ fiber.Ctx) bool {
return true
},
}))
app.Get("/", func(c fiber.Ctx) error {
_, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(nil)
})
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
}
func Test_PaginateEdgeCases(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
url string
expectedPage int
expectedLimit int
}{
{"Negative page", "/?page=-1", 1, 10},
{"Page zero", "/?page=0", 1, 10},
{"Negative limit", "/?limit=-10", 1, 10},
{"Limit zero", "/?limit=0", 1, 10},
{"Limit exceeds max", "/?limit=200", 1, DefaultMaxLimit},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{
DefaultSort: "id",
DefaultLimit: 10,
}))
app.Get("/", func(c fiber.Ctx) error {
pageInfo, ok := FromContext(c)
if !ok {
return fiber.ErrBadRequest
}
return c.JSON(pageInfo)
})
resp, err := app.Test(httptest.NewRequest(http.MethodGet, tc.url, http.NoBody))
require.NoError(t, err)
defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
require.Equal(t, 200, resp.StatusCode)
var result PageInfo
require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
require.Equal(t, tc.expectedPage, result.Page)
require.Equal(t, tc.expectedLimit, result.Limit)
})
}
}
// Test_PaginateWithMultipleSorting checks default, single, multiple and
// disallowed sort parameters against an allow-list.
func Test_PaginateWithMultipleSorting(t *testing.T) {
	t.Parallel()
	testCases := []struct {
		name         string
		url          string
		expectedSort []SortField
	}{
		{"Default Sort", "/", []SortField{{Field: "id", Order: ASC}}},
		{"Single Sort", "/?sort=name", []SortField{{Field: "name", Order: ASC}}},
		{"Multiple Sort", "/?sort=name,-date", []SortField{{Field: "name", Order: ASC}, {Field: "date", Order: DESC}}},
		{"Invalid Sort", "/?sort=invalid", []SortField{{Field: "id", Order: ASC}}},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			app.Use(New(Config{
				SortKey:      "sort",
				DefaultSort:  "id",
				AllowedSorts: []string{"id", "name", "date"},
			}))
			app.Get("/", func(c fiber.Ctx) error {
				pageInfo, ok := FromContext(c)
				if !ok {
					return fiber.ErrBadRequest
				}
				return c.JSON(paginateResponse{
					Sort: pageInfo.Sort,
				})
			})
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tc.url, http.NoBody))
			require.NoError(t, err)
			defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
			// Assert the status first so a failing middleware is reported as such
			// instead of as a confusing JSON decode error.
			require.Equal(t, fiber.StatusOK, resp.StatusCode)
			var result paginateResponse
			require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
			require.Equal(t, tc.expectedSort, result.Sort)
		})
	}
}
// Test_ParseSortQuery exercises parseSortQuery directly: allow-list
// filtering, "-" prefixed descending fields, fallback to the default sort,
// and skipping of bare dashes.
func Test_ParseSortQuery(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name         string
		query        string
		allowedSorts []string
		defaultSort  string
		expected     []SortField
	}{
		{
			name:         "Empty query",
			query:        "",
			allowedSorts: []string{"id", "name", "date"},
			defaultSort:  "id",
			expected:     []SortField{{Field: "id", Order: ASC}},
		},
		{
			name:         "Single allowed field",
			query:        "name",
			allowedSorts: []string{"id", "name", "date"},
			defaultSort:  "id",
			expected:     []SortField{{Field: "name", Order: ASC}},
		},
		{
			name:         "Multiple fields with mixed order",
			query:        "name,-date,id",
			allowedSorts: []string{"id", "name", "date"},
			defaultSort:  "id",
			expected: []SortField{
				{Field: "name", Order: ASC},
				{Field: "date", Order: DESC},
				{Field: "id", Order: ASC},
			},
		},
		{
			name:         "Disallowed field",
			query:        "email,name",
			allowedSorts: []string{"id", "name", "date"},
			defaultSort:  "id",
			expected:     []SortField{{Field: "name", Order: ASC}},
		},
		{
			name:         "All disallowed fields",
			query:        "email,phone",
			allowedSorts: []string{"id", "name", "date"},
			defaultSort:  "id",
			expected:     []SortField{{Field: "id", Order: ASC}},
		},
		{
			name:         "Nil AllowedSorts allows all fields",
			query:        "email,-phone",
			allowedSorts: nil,
			defaultSort:  "id",
			expected: []SortField{
				{Field: "email", Order: ASC},
				{Field: "phone", Order: DESC},
			},
		},
		{
			name:         "Bare dash is skipped",
			query:        "-",
			allowedSorts: nil,
			defaultSort:  "id",
			expected:     []SortField{{Field: "id", Order: ASC}},
		},
		{
			name:         "Dash in comma list is skipped",
			query:        "name,-,email",
			allowedSorts: nil,
			defaultSort:  "id",
			expected: []SortField{
				{Field: "name", Order: ASC},
				{Field: "email", Order: ASC},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got := parseSortQuery(tc.query, tc.allowedSorts, tc.defaultSort)
			require.Equal(t, tc.expected, got)
		})
	}
}
// --- Cursor tests ---
// Test_PaginateWithCursor checks that a valid base64url JSON cursor and an
// explicit limit are exposed unchanged through FromContext.
func Test_PaginateWithCursor(t *testing.T) {
	t.Parallel()

	cursor := base64.RawURLEncoding.EncodeToString([]byte(`{"id":42}`))

	app := fiber.New()
	app.Use(New(Config{DefaultSort: "id"}))
	app.Get("/", func(c fiber.Ctx) error {
		info, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(cursorResponse{
			Cursor: info.Cursor,
			Limit:  info.Limit,
			Sort:   info.Sort,
		})
	})

	target := "/?cursor=" + cursor + "&limit=20"
	resp, err := app.Test(httptest.NewRequest(http.MethodGet, target, http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	var got cursorResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
	require.Equal(t, cursor, got.Cursor)
	require.Equal(t, 20, got.Limit)
}
// Test_PaginateCursorPriorityOverPage checks that when both a cursor and a
// page are supplied, cursor mode wins and Page is left at zero.
func Test_PaginateCursorPriorityOverPage(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(pageInfo)
	})
	cursorJSON := `{"id":42}`
	cursor := base64.RawURLEncoding.EncodeToString([]byte(cursorJSON))
	resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?cursor="+cursor+"&page=5&limit=10", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	// Check the status before decoding, consistent with the sibling tests.
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	var result PageInfo
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
	require.Equal(t, cursor, result.Cursor)
	require.Equal(t, 0, result.Page)
}
// Test_PaginateEmptyCursorIsFirstPage checks that an empty cursor parameter
// is accepted (200) and leaves Cursor empty.
func Test_PaginateEmptyCursorIsFirstPage(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		info, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(info)
	})

	req := httptest.NewRequest(http.MethodGet, "/?cursor=&limit=10", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	var got PageInfo
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
	require.Empty(t, got.Cursor)
}
// Test_PaginateInvalidCursorReturns400 checks that cursors that are not
// base64url, or that decode to something other than JSON, are rejected
// with 400 Bad Request.
func Test_PaginateInvalidCursorReturns400(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name   string
		cursor string
	}{
		{name: "Invalid base64", cursor: "not-valid!!!"},
		{name: "Valid base64 but invalid JSON", cursor: base64.RawURLEncoding.EncodeToString([]byte("not-json"))},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Use(New())
			app.Get("/", func(c fiber.Ctx) error {
				info, _ := FromContext(c)
				return c.JSON(info)
			})

			req := httptest.NewRequest(http.MethodGet, "/?cursor="+tc.cursor, http.NoBody)
			resp, err := app.Test(req)
			require.NoError(t, err)
			defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
			require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
		})
	}
}
// Test_PaginateCursorWithSort checks that sort parameters are still parsed
// when the request is in cursor mode.
func Test_PaginateCursorWithSort(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		SortKey:      "sort",
		DefaultSort:  "id",
		AllowedSorts: []string{"id", "name"},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(cursorResponse{
			Cursor: pageInfo.Cursor,
			Sort:   pageInfo.Sort,
		})
	})
	cursorJSON := `{"id":42}`
	cursor := base64.RawURLEncoding.EncodeToString([]byte(cursorJSON))
	resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?cursor="+cursor+"&sort=name,-id", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	// Check the status before decoding, consistent with the sibling tests.
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	var result cursorResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
	require.Equal(t, []SortField{{Field: "name", Order: ASC}, {Field: "id", Order: DESC}}, result.Sort)
}
// Test_PaginateCursorWithCustomKey checks that the cursor is read from the
// query parameter named by Config.CursorKey.
func Test_PaginateCursorWithCustomKey(t *testing.T) {
	t.Parallel()

	cursor := base64.RawURLEncoding.EncodeToString([]byte(`{"id":1}`))

	app := fiber.New()
	app.Use(New(Config{CursorKey: "after"}))
	app.Get("/", func(c fiber.Ctx) error {
		info, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(cursorResponse{
			Cursor: info.Cursor,
			Limit:  info.Limit,
		})
	})

	req := httptest.NewRequest(http.MethodGet, "/?after="+cursor, http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	var got cursorResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
	require.Equal(t, cursor, got.Cursor)
}
// Test_PaginateCursorWithParamAlias checks that the cursor is read from the
// query parameter named by Config.CursorParam.
func Test_PaginateCursorWithParamAlias(t *testing.T) {
	t.Parallel()

	cursor := base64.RawURLEncoding.EncodeToString([]byte(`{"id":1}`))

	app := fiber.New()
	app.Use(New(Config{CursorParam: "starting_after"}))
	app.Get("/", func(c fiber.Ctx) error {
		info, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(cursorResponse{Cursor: info.Cursor})
	})

	req := httptest.NewRequest(http.MethodGet, "/?starting_after="+cursor, http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	var got cursorResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
	require.Equal(t, cursor, got.Cursor)
}
// Test_PaginateNoCursorFallsBackToPageMode checks that without a cursor the
// middleware uses page/limit and computes Start() as (page-1)*limit.
func Test_PaginateNoCursorFallsBackToPageMode(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		DefaultSort: "id",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(paginateResponse{
			Page:  pageInfo.Page,
			Limit: pageInfo.Limit,
			Start: pageInfo.Start(),
		})
	})
	resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?page=3&limit=15", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	// Check the status before decoding, consistent with the sibling tests.
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	var result paginateResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
	require.Equal(t, 3, result.Page)
	require.Equal(t, 15, result.Limit)
	require.Equal(t, 30, result.Start)
}
// Test_NextPageURLWithKeys checks that the next-page URL uses the supplied
// page and limit query keys.
func Test_NextPageURLWithKeys(t *testing.T) {
	t.Parallel()

	info := PageInfo{Page: 2, Limit: 10}
	got := info.NextPageURLWithKeys("https://example.com/users", "p", "per_page")

	require.Equal(t, "https://example.com/users?p=3&per_page=10", got)
}
func Test_PreviousPageURLWithKeys(t *testing.T) {
t.Parallel()
t.Run("has previous", func(t *testing.T) {
t.Parallel()
p := PageInfo{Page: 3, Limit: 15}
result := p.PreviousPageURLWithKeys("https://example.com/items", "p", "size")
require.Equal(t, "https://example.com/items?p=2&size=15", result)
})
t.Run("first page returns empty", func(t *testing.T) {
t.Parallel()
p := PageInfo{Page: 1, Limit: 15}
result := p.PreviousPageURLWithKeys("https://example.com/items", "p", "size")
require.Empty(t, result)
})
}
func Test_NextCursorURLWithKeys(t *testing.T) {
t.Parallel()
t.Run("has more", func(t *testing.T) {
t.Parallel()
p := &PageInfo{Limit: 20}
require.NoError(t, p.SetNextCursor(map[string]any{"id": float64(42)}))
result := p.NextCursorURLWithKeys("https://example.com/users", "after", "per_page")
expected := fmt.Sprintf("https://example.com/users?after=%s&per_page=20", p.NextCursor)
require.Equal(t, expected, result)
})
t.Run("no more", func(t *testing.T) {
t.Parallel()
p := &PageInfo{Limit: 20}
result := p.NextCursorURLWithKeys("https://example.com/users", "after", "per_page")
require.Empty(t, result)
})
}
// Test_PaginateCustomMaxLimit checks that a configured MaxLimit caps the
// requested limit while smaller limits pass through.
func Test_PaginateCustomMaxLimit(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name          string
		url           string
		expectedLimit int
	}{
		{name: "Limit within custom max", url: "/?limit=40", expectedLimit: 40},
		{name: "Limit exceeds custom max", url: "/?limit=200", expectedLimit: 50},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Use(New(Config{
				DefaultSort:  "id",
				DefaultLimit: 10,
				MaxLimit:     50,
			}))
			app.Get("/", func(c fiber.Ctx) error {
				info, ok := FromContext(c)
				if !ok {
					return fiber.ErrBadRequest
				}
				return c.JSON(info)
			})

			req := httptest.NewRequest(http.MethodGet, tc.url, http.NoBody)
			resp, err := app.Test(req)
			require.NoError(t, err)
			defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
			require.Equal(t, http.StatusOK, resp.StatusCode)

			var got PageInfo
			require.NoError(t, json.NewDecoder(resp.Body).Decode(&got))
			require.Equal(t, tc.expectedLimit, got.Limit)
		})
	}
}
// Test_ConfigDefaultMaxLimitNormalization checks that configDefault replaces
// a zero MaxLimit with DefaultMaxLimit and keeps an explicit value.
func Test_ConfigDefaultMaxLimitNormalization(t *testing.T) {
	t.Parallel()

	// Zero falls back to the package default.
	require.Equal(t, DefaultMaxLimit, configDefault(Config{MaxLimit: 0}).MaxLimit)
	// An explicit positive value is kept as-is.
	require.Equal(t, 50, configDefault(Config{MaxLimit: 50}).MaxLimit)
}
// --- Benchmarks ---
// Benchmark_PaginateMiddleware measures a page-mode request with the
// default configuration.
func Benchmark_PaginateMiddleware(b *testing.B) {
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, _ := FromContext(c)
		return c.JSON(pageInfo)
	})
	testCfg := fiber.TestConfig{Timeout: 0}

	b.ResetTimer()
	for range b.N {
		req := httptest.NewRequest(http.MethodGet, "/?page=2&limit=20&sort=name,-date", http.NoBody)
		resp, err := app.Test(req, testCfg)
		if err != nil {
			b.Fatal(err)
		}
		resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	}
}
// Benchmark_PaginateMiddlewareWithCustomConfig measures a page-mode request
// using custom query keys and a sort allow-list.
func Benchmark_PaginateMiddlewareWithCustomConfig(b *testing.B) {
	cfg := Config{
		PageKey:      "p",
		LimitKey:     "l",
		SortKey:      "s",
		DefaultPage:  1,
		DefaultLimit: 30,
		DefaultSort:  "id",
		AllowedSorts: []string{"id", "name", "date"},
	}

	app := fiber.New()
	app.Use(New(cfg))
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, _ := FromContext(c)
		return c.JSON(pageInfo)
	})
	testCfg := fiber.TestConfig{Timeout: 0}

	b.ResetTimer()
	for range b.N {
		req := httptest.NewRequest(http.MethodGet, "/?p=3&l=25&s=name,-id", http.NoBody)
		resp, err := app.Test(req, testCfg)
		if err != nil {
			b.Fatal(err)
		}
		resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	}
}
// Benchmark_PaginateCursorMiddleware measures a cursor-mode request that also
// carries a limit and a multi-field sort.
func Benchmark_PaginateCursorMiddleware(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{
		SortKey:      "sort",
		DefaultSort:  "id",
		AllowedSorts: []string{"id", "name", "date"},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, _ := FromContext(c)
		return c.JSON(pageInfo)
	})

	cursor := base64.RawURLEncoding.EncodeToString([]byte(`{"id":42,"created_at":"2026-01-01T00:00:00Z"}`))
	target := "/?cursor=" + cursor + "&limit=20&sort=name,-id"
	testCfg := fiber.TestConfig{Timeout: 0}

	b.ResetTimer()
	for range b.N {
		req := httptest.NewRequest(http.MethodGet, target, http.NoBody)
		resp, err := app.Test(req, testCfg)
		if err != nil {
			b.Fatal(err)
		}
		resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	}
}
================================================
FILE: middleware/pprof/config.go
================================================
package pprof
import (
"github.com/gofiber/fiber/v3"
)
// Config defines the config for the pprof middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// Prefix defines a URL prefix added before "/debug/pprof".
	// It should start with (but not end with) a slash, because it is
	// concatenated directly with "/debug/pprof" to form the handled path.
	// Example: "/federated-fiber"
	//
	// Optional. Default: ""
	Prefix string
}
// ConfigDefault is the default config: no Next filter and an empty Prefix,
// so the endpoints are served under "/debug/pprof".
var ConfigDefault = Config{
	Prefix: "",
	Next:   nil,
}
// configDefault returns ConfigDefault when no config is given. Otherwise it
// returns the first config, filling unset fields from ConfigDefault.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}

	cfg := config[0]
	if cfg.Prefix == "" {
		cfg.Prefix = ConfigDefault.Prefix
	}
	if cfg.Next == nil {
		cfg.Next = ConfigDefault.Next
	}
	return cfg
}
================================================
FILE: middleware/pprof/pprof.go
================================================
package pprof
import (
"net/http/pprof"
"strings"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp/fasthttpadaptor"
"github.com/gofiber/fiber/v3"
)
// New creates a new middleware handler that serves the net/http/pprof
// endpoints under cfg.Prefix + "/debug/pprof".
//
// It handles only the prefix itself and paths that continue it with "/". All
// other requests are passed on with c.Next(). Known profile paths are served
// by the matching net/http/pprof handler through fasthttpadaptor. Any other
// sub-path is redirected. If it has trailing slashes, the target is the same
// path under the prefix without them. Otherwise the target is the index at
// prefix + "/".
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)

	// Set pprof adaptors
	var (
		pprofIndex        = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index)
		pprofCmdline      = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline)
		pprofProfile      = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile)
		pprofSymbol       = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol)
		pprofTrace        = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace)
		pprofAllocs       = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("allocs").ServeHTTP)
		pprofBlock        = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("block").ServeHTTP)
		pprofGoroutine    = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("goroutine").ServeHTTP)
		pprofHeap         = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("heap").ServeHTTP)
		pprofMutex        = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("mutex").ServeHTTP)
		pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP)
	)

	// Construct the actual prefix and the index location. The pprof index
	// only works with a trailing slash.
	prefix := cfg.Prefix + "/debug/pprof"
	index := prefix + "/"

	// Return new handler
	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// We are only interested in /debug/pprof routes. Require a path-segment
		// boundary after the prefix so unrelated routes such as
		// "/debug/pprofile" are not captured and redirected.
		path, found := strings.CutPrefix(c.Path(), prefix)
		if !found || (path != "" && path[0] != '/') {
			return c.Next()
		}

		// Switch on trimmed path against constant strings
		switch path {
		case "/":
			pprofIndex(c.RequestCtx())
		case "/cmdline":
			pprofCmdline(c.RequestCtx())
		case "/profile":
			pprofProfile(c.RequestCtx())
		case "/symbol":
			pprofSymbol(c.RequestCtx())
		case "/trace":
			pprofTrace(c.RequestCtx())
		case "/allocs":
			pprofAllocs(c.RequestCtx())
		case "/block":
			pprofBlock(c.RequestCtx())
		case "/goroutine":
			pprofGoroutine(c.RequestCtx())
		case "/heap":
			pprofHeap(c.RequestCtx())
		case "/mutex":
			pprofMutex(c.RequestCtx())
		case "/threadcreate":
			pprofThreadcreate(c.RequestCtx())
		default:
			// A sub-path with trailing slashes is retried without them, and the
			// prefix is kept ("<prefix>/heap/" -> "<prefix>/heap"). Previously the
			// prefix was dropped. Everything else, including "<prefix>" and
			// "<prefix>//", goes to the index.
			target := index
			if trimmed := utils.TrimRight(path, '/'); trimmed != path && trimmed != "" {
				target = prefix + trimmed
			}
			return c.Redirect().To(target)
		}
		return nil
	}
}
================================================
FILE: middleware/pprof/pprof_test.go
================================================
package pprof
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
// testConfig allows each app.Test call up to five seconds and sets
// FailOnTimeout. It is used for the profile endpoints, some of which (profile,
// trace) sample for a while.
var testConfig = fiber.TestConfig{
	FailOnTimeout: true,
	Timeout:       5 * time.Second,
}
// Test_Non_Pprof_Path checks that requests outside /debug/pprof reach the
// application's own routes.
func Test_Non_Pprof_Path(t *testing.T) {
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "escaped", string(body))
}
// Test_Non_Pprof_Path_WithPrefix checks that, with a custom prefix, requests
// outside the pprof path still reach the application's own routes.
func Test_Non_Pprof_Path_WithPrefix(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{Prefix: "/federated-fiber"}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "escaped", string(body))
}
// Test_Pprof_Index checks that /debug/pprof/ serves the HTML index page.
func Test_Pprof_Index(t *testing.T) {
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/debug/pprof/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.True(t, bytes.Contains(body, []byte("/debug/pprof/")))
}
// Test_Pprof_Index_WithPrefix checks that the HTML index is served under a
// custom prefix.
func Test_Pprof_Index_WithPrefix(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{Prefix: "/federated-fiber"}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Contains(t, string(body), "/debug/pprof/")
}
// Test_Pprof_Subs checks that every named pprof endpoint answers with 200.
// The subtests run sequentially (no t.Parallel).
func Test_Pprof_Subs(t *testing.T) {
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	endpoints := []string{
		"cmdline", "profile", "symbol", "trace", "allocs", "block",
		"goroutine", "heap", "mutex", "threadcreate",
	}

	for _, name := range endpoints {
		t.Run(name, func(t *testing.T) {
			target := "/debug/pprof/" + name
			if name == "profile" {
				// Keep the CPU profile to one second so the test stays short.
				target += "?seconds=1"
			}

			req := httptest.NewRequest(fiber.MethodGet, target, http.NoBody)
			resp, err := app.Test(req, testConfig)
			require.NoError(t, err)
			require.Equal(t, fiber.StatusOK, resp.StatusCode)
		})
	}
}
// Test_Pprof_Subs_WithPrefix checks that every named pprof endpoint answers
// with 200 under a custom prefix. The subtests run sequentially.
func Test_Pprof_Subs_WithPrefix(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{Prefix: "/federated-fiber"}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	endpoints := []string{
		"cmdline", "profile", "symbol", "trace", "allocs", "block",
		"goroutine", "heap", "mutex", "threadcreate",
	}

	for _, name := range endpoints {
		t.Run(name, func(t *testing.T) {
			target := "/federated-fiber/debug/pprof/" + name
			if name == "profile" {
				// Keep the CPU profile to one second so the test stays short.
				target += "?seconds=1"
			}

			req := httptest.NewRequest(fiber.MethodGet, target, http.NoBody)
			resp, err := app.Test(req, testConfig)
			require.NoError(t, err)
			require.Equal(t, fiber.StatusOK, resp.StatusCode)
		})
	}
}
// Test_Pprof_Other checks that an unknown sub-path is answered with a
// 303 redirect rather than served.
func Test_Pprof_Other(t *testing.T) {
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/debug/pprof/303", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusSeeOther, resp.StatusCode)
}
// Test_Pprof_Other_WithPrefix checks that an unknown sub-path under a custom
// prefix is answered with a 303 redirect.
func Test_Pprof_Other_WithPrefix(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{Prefix: "/federated-fiber"}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("escaped")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/303", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusSeeOther, resp.StatusCode)
}
// go test -run Test_Pprof_Next
//
// Test_Pprof_Next checks that when Next returns true the middleware is
// skipped. With no routes registered, the request then ends in 404.
func Test_Pprof_Next(t *testing.T) {
	skipAll := func(_ fiber.Ctx) bool { return true }

	app := fiber.New()
	app.Use(New(Config{Next: skipAll}))

	req := httptest.NewRequest(fiber.MethodGet, "/debug/pprof/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_Pprof_Next_WithPrefix
//
// Test_Pprof_Next_WithPrefix checks that Next also skips the middleware when
// a custom prefix is configured, so the request ends in 404.
func Test_Pprof_Next_WithPrefix(t *testing.T) {
	cfg := Config{
		Prefix: "/federated-fiber",
		Next:   func(_ fiber.Ctx) bool { return true },
	}

	app := fiber.New()
	app.Use(New(cfg))

	req := httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
================================================
FILE: middleware/proxy/config.go
================================================
package proxy
import (
"crypto/tls"
"time"
"github.com/gofiber/fiber/v3"
"github.com/valyala/fasthttp"
)
// Config defines the config for the proxy Balancer middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// ModifyRequest allows you to alter the request before it is forwarded.
	// A non-nil error aborts proxying and is returned from the handler.
	//
	// Optional. Default: nil
	ModifyRequest fiber.Handler

	// ModifyResponse allows you to alter the response after the upstream
	// server has answered. A non-nil error is returned from the handler.
	//
	// Optional. Default: nil
	ModifyResponse fiber.Handler

	// TLSConfig is the TLS config passed to each upstream HTTP client.
	//
	// Optional. Default: nil
	TLSConfig *tls.Config

	// Client is a custom load-balancing client, for when the client config is
	// complex. Note that Servers, Timeout, WriteBufferSize, ReadBufferSize,
	// TLSConfig and DialDualStack will not be used if Client is set.
	//
	// Optional. Default: nil
	Client *fasthttp.LBClient

	// Servers defines a list of <scheme>://<host> HTTP servers. Requests are
	// distributed among them by a fasthttp.LBClient.
	// e.g.: "https://foobar.com", "http://www.foobar.com"
	//
	// Required unless Client is set.
	Servers []string

	// Timeout is the request timeout used when calling the proxy client.
	// Non-positive values are replaced by the default.
	//
	// Optional. Default: 1 second
	Timeout time.Duration

	// Per-connection buffer size for requests' reading.
	// This also limits the maximum header size.
	// Increase this buffer if your clients send multi-KB RequestURIs
	// and/or multi-KB headers (for example, BIG cookies).
	ReadBufferSize int

	// Per-connection buffer size for responses' writing.
	WriteBufferSize int

	// KeepConnectionHeader keeps the "Connection" header when set to true.
	// When false, the header is removed from the request before it is
	// forwarded and from the upstream response.
	//
	// Optional. Default: false
	KeepConnectionHeader bool

	// Attempt to connect to both ipv4 and ipv6 host addresses if set to true.
	//
	// By default client connects only to ipv4 addresses, since unfortunately ipv6
	// remains broken in many networks worldwide :)
	//
	// Optional. Default: false
	DialDualStack bool
}
// ConfigDefault is the default config. It has no upstream servers or custom
// client. Its timeout is fasthttp.DefaultLBClientTimeout, and the Connection
// header is not kept.
var ConfigDefault = Config{
	Timeout:              fasthttp.DefaultLBClientTimeout,
	KeepConnectionHeader: false,
	Next:                 nil,
	ModifyRequest:        nil,
	ModifyResponse:       nil,
}
// configDefault returns the first provided config, or ConfigDefault when none
// is given. A non-positive Timeout is replaced with ConfigDefault.Timeout.
//
// It panics with "Servers cannot be empty" when neither Servers nor Client is
// set. This also applies to the no-argument case. ConfigDefault has no
// servers, so a Balancer built from it would have no upstreams at all.
// Failing at construction is preferable to failing on every request.
func configDefault(config ...Config) Config {
	cfg := ConfigDefault
	if len(config) > 0 {
		cfg = config[0]
	}

	// Set default values
	if cfg.Timeout <= 0 {
		cfg.Timeout = ConfigDefault.Timeout
	}

	// Validate: an upstream is required unless a custom client is provided.
	if len(cfg.Servers) == 0 && cfg.Client == nil {
		panic("Servers cannot be empty")
	}
	return cfg
}
================================================
FILE: middleware/proxy/proxy.go
================================================
package proxy
import (
"bytes"
"errors"
"net/url"
"strings"
"sync"
"time"
"github.com/gofiber/utils/v2"
"github.com/gofiber/fiber/v3"
"github.com/valyala/fasthttp"
)
// Balancer creates a load balancer among multiple upstream servers.
//
// If cfg.Client is unset, Balancer creates one fasthttp.HostClient per entry
// in cfg.Servers and combines them into a fasthttp.LBClient that uses
// cfg.Timeout. Entries without a "://" scheme separator are treated as plain
// HTTP, so "localhost:3000" or "http-backend:8080" become "http://...".
// Balancer panics when a server entry cannot be parsed as a URL.
//
// When cfg.Client is set, Servers, Timeout, WriteBufferSize, ReadBufferSize,
// TLSConfig and DialDualStack are ignored.
func Balancer(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)

	// Load balanced client: use the custom one if provided, otherwise build it.
	lbc := cfg.Client
	if lbc == nil {
		lbc = &fasthttp.LBClient{Timeout: cfg.Timeout}
		for _, server := range cfg.Servers {
			// Look for an actual scheme rather than an "http" prefix. Host names
			// such as "httpbin.org" or "http-backend:8080" also start with "http"
			// but carry no scheme, and url.Parse would give them an empty Host.
			if !strings.Contains(server, "://") {
				server = "http://" + server
			}
			u, err := url.Parse(server)
			if err != nil {
				panic(err)
			}
			lbc.Clients = append(lbc.Clients, &fasthttp.HostClient{
				NoDefaultUserAgentHeader: true,
				DisablePathNormalizing:   true,
				Addr:                     u.Host,
				ReadBufferSize:           cfg.ReadBufferSize,
				WriteBufferSize:          cfg.WriteBufferSize,
				TLSConfig:                cfg.TLSConfig,
				DialDualStack:            cfg.DialDualStack,
			})
		}
	}

	// Return new handler
	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Set request and response
		req := c.Request()
		res := c.Response()

		if !cfg.KeepConnectionHeader {
			// Don't proxy "Connection" header
			req.Header.Del(fiber.HeaderConnection)
		}

		// Modify request
		if cfg.ModifyRequest != nil {
			if err := cfg.ModifyRequest(c); err != nil {
				return err
			}
		}

		if c.App().Config().Immutable {
			req.SetRequestURIBytes(req.RequestURI())
		} else {
			req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
		}

		// Forward request
		if err := lbc.Do(req, res); err != nil {
			return err
		}

		if !cfg.KeepConnectionHeader {
			// Don't proxy "Connection" header
			res.Header.Del(fiber.HeaderConnection)
		}

		// Modify response
		if cfg.ModifyResponse != nil {
			if err := cfg.ModifyResponse(c); err != nil {
				return err
			}
		}

		// Return nil to end proxying if no error
		return nil
	}
}
// client is the package-wide client used by Do, Forward and the other Do*
// helpers when no per-call client is passed. Replace it with WithClient.
var client = &fasthttp.Client{
	DisablePathNormalizing:   true,
	NoDefaultUserAgentHeader: true,
}

// Errors returned by selectClient when no usable client is available.
var (
	errNilProxyClientOverride = errors.New("proxy: nil client override passed to Do/Forward")
	errNilGlobalProxyClient   = errors.New("proxy: global client is nil, set a non-nil client with proxy.WithClient")
)

// lock guards reads and writes of client.
var lock sync.RWMutex
// WithClient replaces the global proxy client used by Do, Forward and the
// other Do* helpers when no per-call client is passed. It panics if cli is
// nil. Call it before Do and Forward are used.
func WithClient(cli *fasthttp.Client) {
	if cli == nil {
		panic("proxy: WithClient requires a non-nil *fasthttp.Client")
	}

	lock.Lock()
	client = cli
	lock.Unlock()
}
// Forward returns a fiber.Handler that proxies every request to addr via Do.
// If clients is non-empty, its first element is used instead of the global
// client.
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler {
	handler := func(c fiber.Ctx) error {
		return Do(c, addr, clients...)
	}
	return handler
}
// Do sends the current request to addr and fills the current response with
// the upstream answer. It can be used within a fiber.Handler. If clients is
// non-empty, its first element is used instead of the global client.
func Do(c fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
	// The method expression has exactly the signature doAction expects.
	return doAction(c, addr, (*fasthttp.Client).Do, clients...)
}
// DoRedirects works like Do but follows up to maxRedirectsCount redirects.
// When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is
// returned. It can be used within a fiber.Handler.
func DoRedirects(c fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error {
	followRedirects := func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
		return cli.DoRedirects(req, resp, maxRedirectsCount)
	}
	return doAction(c, addr, followRedirects, clients...)
}
// DoDeadline works like Do but waits for the upstream response only until
// deadline. It can be used within a fiber.Handler.
func DoDeadline(c fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error {
	untilDeadline := func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
		return cli.DoDeadline(req, resp, deadline)
	}
	return doAction(c, addr, untilDeadline, clients...)
}
// DoTimeout works like Do but waits at most timeout for the upstream
// response. It can be used within a fiber.Handler.
func DoTimeout(c fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error {
	withinTimeout := func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
		return cli.DoTimeout(req, resp, timeout)
	}
	return doAction(c, addr, withinTimeout, clients...)
}
// doAction is the shared implementation of Do and its variants.
//
// It picks the client (the first override, or the global client), points the
// current request at addr, and runs action to send it and fill the response.
// The request URI is restored to the original URL when doAction returns.
// The Connection header is stripped before sending and from the response.
func doAction(
	c fiber.Ctx,
	addr string,
	action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error,
	clients ...*fasthttp.Client,
) error {
	// Snapshot the global client under the read lock.
	lock.RLock()
	shared := client
	lock.RUnlock()

	cli, err := selectClient(shared, clients...)
	if err != nil {
		return err
	}

	req, res := c.Request(), c.Response()

	// Restore the original request URI once doAction returns.
	originalURL := utils.CopyString(c.OriginalURL())
	defer req.SetRequestURI(originalURL)

	target := utils.CopyString(addr)
	req.SetRequestURI(target)
	// NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https.
	// Reference: https://github.com/gofiber/fiber/issues/1762
	if scheme := getScheme(utils.UnsafeBytes(target)); len(scheme) > 0 {
		req.URI().SetSchemeBytes(scheme)
	}

	// Connection is a hop-by-hop header, so it is not forwarded either way.
	req.Header.Del(fiber.HeaderConnection)
	if err := action(cli, req, res); err != nil {
		return err
	}
	res.Header.Del(fiber.HeaderConnection)

	return nil
}
// selectClient returns the client to proxy with. The first entry of clients
// wins when present, and a nil override is an error. Otherwise globalClient
// is used, and a nil global client is an error.
func selectClient(globalClient *fasthttp.Client, clients ...*fasthttp.Client) (*fasthttp.Client, error) {
	hasOverride := len(clients) > 0

	switch {
	case hasOverride && clients[0] == nil:
		return nil, errNilProxyClientOverride
	case hasOverride:
		return clients[0], nil
	case globalClient == nil:
		return nil, errNilGlobalProxyClient
	default:
		return globalClient, nil
	}
}
// getScheme returns the scheme of uri, i.e. the bytes before "://". It
// returns nil when the first '/' in uri is not immediately preceded by ':'
// and followed by another '/'. The result aliases uri.
func getScheme(uri []byte) []byte {
	slash := bytes.IndexByte(uri, '/')

	// Need ':' right before the first slash and a second slash right after it.
	hasScheme := slash >= 1 &&
		uri[slash-1] == ':' &&
		slash+1 < len(uri) &&
		uri[slash+1] == '/'
	if !hasScheme {
		return nil
	}
	return uri[:slash-1]
}
// DomainForward returns a fiber.Handler that proxies requests whose Host
// equals hostname to addr + the original URL via Do. For any other host the
// handler returns nil without calling c.Next(), which ends the chain.
func DomainForward(hostname, addr string, clients ...*fasthttp.Client) fiber.Handler {
	return func(c fiber.Ctx) error {
		if utils.UnsafeString(c.Request().Host()) != hostname {
			return nil
		}
		return Do(c, addr+c.OriginalURL(), clients...)
	}
}
// roundrobin hands out the entries of pool in order, wrapping around at the
// end. It is safe for concurrent use through the embedded mutex.
type roundrobin struct {
	pool    []string
	current int
	sync.Mutex
}

// get returns the next server address from the pool and advances the
// position. It panics if the pool is empty.
func (r *roundrobin) get() string {
	r.Lock()
	defer r.Unlock()

	idx := r.current
	if idx >= len(r.pool) {
		idx %= len(r.pool)
	}
	r.current = idx + 1
	return r.pool[idx]
}
// BalancerForward returns a fiber.Handler that forwards each request to one
// of servers, picked in round-robin order, via Do.
//
// Entries without a "://" scheme separator are treated as plain HTTP. The
// client IP (c.IP()) is sent upstream in the X-Real-IP header and replaces
// any value supplied by the client.
func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler {
	// Normalize the schemes once instead of on every request. Look for a real
	// "://" rather than an "http" prefix, because host names such as
	// "http-backend:8080" also start with "http". A fresh slice also keeps the
	// pool independent of later changes to the caller's slice.
	pool := make([]string, len(servers))
	for i, server := range servers {
		if !strings.Contains(server, "://") {
			server = "http://" + server
		}
		pool[i] = server
	}

	r := &roundrobin{
		current: 0,
		pool:    pool,
	}

	return func(c fiber.Ctx) error {
		server := r.get()
		// Set, not Add: with Add, a client-supplied X-Real-IP would stay first
		// and be read by upstreams, allowing IP spoofing.
		c.Request().Header.Set("X-Real-IP", c.IP())
		return Do(c, server+c.OriginalURL(), clients...)
	}
}
================================================
FILE: middleware/proxy/proxy_test.go
================================================
package proxy
import (
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v3"
clientpkg "github.com/gofiber/fiber/v3/client"
"github.com/stretchr/testify/require"
"github.com/gofiber/fiber/v3/internal/tlstest"
"github.com/valyala/fasthttp"
)
// startServer serves app on ln in a background goroutine without the startup
// banner, and panics if the listener returns an error.
func startServer(app *fiber.App, ln net.Listener) {
	listenCfg := fiber.ListenConfig{DisableStartupMessage: true}
	go func() {
		if err := app.Listener(ln, listenCfg); err != nil {
			panic(err)
		}
	}()
}
// createProxyTestServer starts a fiber app that answers GET / with handler on a
// new listener for network/address. It returns the app and the bound address.
// The listener is closed when the test finishes, as createRedirectServer already
// does, so upstream servers do not pile up across the many parallel tests here.
func createProxyTestServer(t *testing.T, handler fiber.Handler, network, address string) (target *fiber.App, addr string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned target app and address for readability
	t.Helper()
	target = fiber.New()
	target.Get("/", handler)
	ln, err := net.Listen(network, address)
	require.NoError(t, err)
	t.Cleanup(func() {
		ln.Close() //nolint:errcheck // It is fine to ignore the error here
	})
	addr = ln.Addr().String()
	startServer(target, ln)
	return target, addr
}
// createProxyTestServerIPv4 starts a test upstream on an ephemeral IPv4 loopback port.
func createProxyTestServerIPv4(t *testing.T, handler fiber.Handler) (target *fiber.App, addr string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned target app and address for readability
	t.Helper()
	target, addr = createProxyTestServer(t, handler, fiber.NetworkTCP4, "127.0.0.1:0")
	return target, addr
}

// createProxyTestServerIPv6 starts a test upstream on an ephemeral IPv6 loopback port.
func createProxyTestServerIPv6(t *testing.T, handler fiber.Handler) (target *fiber.App, addr string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned target app and address for readability
	t.Helper()
	target, addr = createProxyTestServer(t, handler, fiber.NetworkTCP6, "[::1]:0")
	return target, addr
}
// createRedirectServer starts a server on an ephemeral IPv4 loopback port.
// GET / answers 301 (body "redirect") pointing at /final, and GET /final answers
// "final". It returns the bound address; the listener is closed on test cleanup.
func createRedirectServer(t *testing.T) string {
	t.Helper()

	ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = ln.Close() // best-effort cleanup
	})
	addr := ln.Addr().String()

	app := fiber.New()
	app.Get("/", func(c fiber.Ctx) error {
		c.Location("http://" + addr + "/final")
		return c.Status(fiber.StatusMovedPermanently).SendString("redirect")
	})
	app.Get("/final", func(c fiber.Ctx) error {
		return c.SendString("final")
	})

	startServer(app, ln)
	return addr
}
// go test -run Test_Proxy_Empty_Upstream_Servers
// Balancer must refuse an empty server list. PanicsWithValue makes the test fail
// when no panic happens; the previous recover-based check passed silently then.
func Test_Proxy_Empty_Upstream_Servers(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	require.PanicsWithValue(t, "Servers cannot be empty", func() {
		app.Use(Balancer(Config{Servers: []string{}}))
	})
}
// go test -run Test_Proxy_Empty_Config
// A zero Config (no Servers, no Client) must make Balancer panic. PanicsWithValue
// fails the test if the panic never happens, unlike a bare deferred recover.
func Test_Proxy_Empty_Config(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	require.PanicsWithValue(t, "Servers cannot be empty", func() {
		app.Use(Balancer(Config{}))
	})
}
// go test -run Test_Proxy_Next
// When Next returns true the balancer is skipped, so the unrouted request falls
// through to the 404 handler.
func Test_Proxy_Next(t *testing.T) {
	t.Parallel()

	skipAlways := func(_ fiber.Ctx) bool { return true }

	app := fiber.New()
	app.Use(Balancer(Config{
		Next:    skipAlways,
		Servers: []string{"127.0.0.1"},
	}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_Proxy
// The balancer relays the upstream's 418 status for a plain GET.
func Test_Proxy(t *testing.T) {
	t.Parallel()

	upstream, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	// Sanity-check the upstream directly first.
	directResp, err := upstream.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	})
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, directResp.StatusCode)

	proxyApp := fiber.New()
	proxyApp.Use(Balancer(Config{Servers: []string{upstreamAddr}}))

	proxiedReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	proxiedReq.Host = upstreamAddr
	proxiedResp, err := proxyApp.Test(proxiedReq)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, proxiedResp.StatusCode)
}
// go test -run Test_Proxy_Balancer_WithTlsConfig
// Starts a TLS server that also has a Balancer configured with a client TLS
// config, then issues an HTTPS request to it.
//
// NOTE(review): GET /tlsbalancer is registered before the Balancer middleware,
// and its handler does not call c.Next(). The request therefore appears to be
// answered by the route directly, and the Balancer's TLS client path is not
// exercised. Consider a separate upstream/proxy pair.
func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) {
	t.Parallel()

	serverTLSConf, _, err := tlstest.GetTLSConfigs()
	require.NoError(t, err)

	rawLn, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	ln := tls.NewListener(rawLn, serverTLSConf)
	addr := ln.Addr().String()

	// Certificate verification is disabled for both the balancer and the test client.
	clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine

	app := fiber.New()
	app.Get("/tlsbalancer", func(c fiber.Ctx) error {
		return c.SendString("tls balancer")
	})
	app.Use(Balancer(Config{
		Servers:   []string{addr},
		TLSConfig: clientTLSConf,
	}))
	startServer(app, ln)

	cli := clientpkg.New()
	cli.SetTLSConfig(clientTLSConf)
	resp, err := cli.Get("https://" + addr + "/tlsbalancer")
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode())
	require.Equal(t, "tls balancer", string(resp.Body()))
	resp.Close()
}
// go test -run Test_Proxy_Balancer_IPv6_Upstream
// Without DialDualStack, proxying to an upstream on IPv6 loopback ends in a 500,
// even though the upstream itself answers 418.
func Test_Proxy_Balancer_IPv6_Upstream(t *testing.T) {
	t.Parallel()

	upstream, upstreamAddr := createProxyTestServerIPv6(t, func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	directResp, err := upstream.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	})
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, directResp.StatusCode)

	app := fiber.New()
	app.Use(Balancer(Config{Servers: []string{upstreamAddr}}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Host = upstreamAddr
	proxiedResp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, proxiedResp.StatusCode)
}
// go test -run Test_Proxy_Balancer_IPv6_Upstream_With_DialDualStack
// With DialDualStack enabled, the balancer relays the IPv6 upstream's 418.
func Test_Proxy_Balancer_IPv6_Upstream_With_DialDualStack(t *testing.T) {
	t.Parallel()

	upstream, upstreamAddr := createProxyTestServerIPv6(t, func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	directResp, err := upstream.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	})
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, directResp.StatusCode)

	app := fiber.New()
	app.Use(Balancer(Config{
		Servers:       []string{upstreamAddr},
		DialDualStack: true,
	}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Host = upstreamAddr
	proxiedResp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, proxiedResp.StatusCode)
}
// go test -run Test_Proxy_Balancer_IPv4_Upstream_With_DialDualStack
// Enabling DialDualStack must not break proxying to a plain IPv4 upstream.
func Test_Proxy_Balancer_IPv4_Upstream_With_DialDualStack(t *testing.T) {
	t.Parallel()

	upstream, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	directResp, err := upstream.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	})
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, directResp.StatusCode)

	app := fiber.New()
	app.Use(Balancer(Config{
		Servers:       []string{upstreamAddr},
		DialDualStack: true,
	}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Host = upstreamAddr
	proxiedResp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, proxiedResp.StatusCode)
}
// go test -run Test_Proxy_Forward_WithTlsConfig_To_Http
// A proxy served over TLS forwards to a plain-HTTP upstream.
func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString("hello from target")
	})

	tlsConf, _, err := tlstest.GetTLSConfigs()
	require.NoError(t, err)

	plainLn, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	tlsLn := tls.NewListener(plainLn, tlsConf)
	proxyAddr := tlsLn.Addr().String()

	proxyApp := fiber.New()
	proxyApp.Use(Forward("http://" + upstreamAddr))
	startServer(proxyApp, tlsLn)

	cli := clientpkg.New()
	cli.SetTimeout(5 * time.Second)
	cli.TLSConfig().InsecureSkipVerify = true

	resp, err := cli.Get("https://" + proxyAddr)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode())
	require.Equal(t, "hello from target", string(resp.Body()))
	resp.Close()
}
// go test -run Test_Proxy_Forward
func Test_Proxy_Forward(t *testing.T) {
t.Parallel()
app := fiber.New()
_, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
return c.SendString("forwarded")
})
app.Use(Forward("http://" + addr))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "forwarded", string(b))
}
// go test -run Test_Proxy_Forward_WithClient_TLSConfig
// Forward, with no per-call client, uses the package-level client installed by
// WithClient, here one that skips certificate verification.
//
// Deliberately NOT t.Parallel(): WithClient swaps the shared proxy client, and
// Test_Proxy_Forward_Global_Client does the same. Run in parallel, the other test
// could replace this client (dropping InsecureSkipVerify) between WithClient and
// the request, making the result order-dependent. Sequential tests finish before
// any parallel test resumes, so this one sees its own client.
func Test_Proxy_Forward_WithClient_TLSConfig(t *testing.T) {
	serverTLSConf, _, err := tlstest.GetTLSConfigs()
	require.NoError(t, err)
	ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	ln = tls.NewListener(ln, serverTLSConf)
	app := fiber.New()
	app.Get("/tlsfwd", func(c fiber.Ctx) error {
		return c.SendString("tls forward")
	})
	addr := ln.Addr().String()
	clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine
	// disable certificate verification
	WithClient(&fasthttp.Client{
		TLSConfig: clientTLSConf,
	})
	app.Use(Forward("https://" + addr + "/tlsfwd"))
	startServer(app, ln)
	client := clientpkg.New()
	client.SetTLSConfig(clientTLSConf)
	resp, err := client.Get("https://" + addr)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode())
	require.Equal(t, "tls forward", string(resp.Body()))
	resp.Close()
}
// go test -run Test_Proxy_Modify_Response
// ModifyResponse can rewrite both the status and the body returned by the upstream.
func Test_Proxy_Modify_Response(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.Status(fiber.StatusInternalServerError).SendString("not modified")
	})

	rewrite := func(c fiber.Ctx) error {
		c.Response().SetStatusCode(fiber.StatusOK)
		return c.SendString("modified response")
	}

	app := fiber.New()
	app.Use(Balancer(Config{
		Servers:        []string{upstreamAddr},
		ModifyResponse: rewrite,
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "modified response", string(body))
}
// go test -run Test_Proxy_Modify_Request
// ModifyRequest can replace the request body before it reaches the upstream,
// which echoes it back.
func Test_Proxy_Modify_Request(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString(string(c.Request().Body()))
	})

	rewrite := func(c fiber.Ctx) error {
		c.Request().SetBody([]byte("modified request"))
		return nil
	}

	app := fiber.New()
	app.Use(Balancer(Config{
		Servers:       []string{upstreamAddr},
		ModifyRequest: rewrite,
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	echoed, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "modified request", string(echoed))
}
// go test -run Test_Proxy_Timeout_Slow_Server
// An upstream slower than usual (300ms) but within the balancer Timeout (600ms)
// still succeeds.
func Test_Proxy_Timeout_Slow_Server(t *testing.T) {
	t.Parallel()

	const upstreamDelay = 300 * time.Millisecond
	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		time.Sleep(upstreamDelay)
		return c.SendString("fiber is awesome")
	})

	app := fiber.New()
	app.Use(Balancer(Config{
		Servers: []string{upstreamAddr},
		Timeout: 2 * upstreamDelay,
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	})
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "fiber is awesome", string(body))
}
// go test -run Test_Proxy_With_Timeout
// An upstream slower than the balancer Timeout yields a 500 with body "timeout".
func Test_Proxy_With_Timeout(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		time.Sleep(time.Second)
		return c.SendString("fiber is awesome")
	})

	app := fiber.New()
	app.Use(Balancer(Config{
		Servers: []string{upstreamAddr},
		Timeout: 100 * time.Millisecond,
	}))

	testCfg := fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	}
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), testCfg)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "timeout", string(body))
}
// go test -run Test_Proxy_Buffer_Size_Response
// A ~5KB response header overflows the default read buffer (500), but fits once
// ReadBufferSize is raised to 8KB (200).
func Test_Proxy_Buffer_Size_Response(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		// 4999 dashes, the same value as joining 5000 empty strings with "-".
		c.Set("Very-Long-Header", strings.Repeat("-", 4999))
		return c.SendString("ok")
	})

	defaultBuf := fiber.New()
	defaultBuf.Use(Balancer(Config{Servers: []string{upstreamAddr}, KeepConnectionHeader: true}))
	resp, err := defaultBuf.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)

	largeBuf := fiber.New()
	largeBuf.Use(Balancer(Config{
		Servers:        []string{upstreamAddr},
		ReadBufferSize: 8 * 1024,
	}))
	resp, err = largeBuf.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
// go test -race -run Test_Proxy_Do_RestoreOriginalURL
// After Do proxies the request, the request URL seen by the client is still /test.
func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString("proxied")
	})
	upstreamURL := "http://" + upstreamAddr

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		return Do(c, upstreamURL)
	})

	resp, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
	require.NoError(t, testErr)
	require.Equal(t, "/test", resp.Request.URL.String())
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	payload, readErr := io.ReadAll(resp.Body)
	require.NoError(t, readErr)
	require.Equal(t, "proxied", string(payload))
}
// go test -race -run Test_Proxy_Do_WithRealURL
// Do reaches a real upstream over TCP and relays its body.
func Test_Proxy_Do_WithRealURL(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString("real url")
	})
	upstreamURL := "http://" + upstreamAddr

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		return Do(c, upstreamURL)
	})

	testCfg := fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	}
	resp, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), testCfg)
	require.NoError(t, testErr)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, "/test", resp.Request.URL.String())

	payload, readErr := io.ReadAll(resp.Body)
	require.NoError(t, readErr)
	require.Equal(t, "real url", string(payload))
}
// go test -race -run Test_Proxy_Do_WithRedirect
// Do does not follow redirects: the upstream's 301 and body are relayed as-is.
func Test_Proxy_Do_WithRedirect(t *testing.T) {
	t.Parallel()

	redirectAddr := createRedirectServer(t)

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		return Do(c, "http://"+redirectAddr)
	})

	testCfg := fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	}
	resp, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), testCfg)
	require.NoError(t, testErr)

	payload, readErr := io.ReadAll(resp.Body)
	require.NoError(t, readErr)
	require.Equal(t, "redirect", string(payload))
	require.Equal(t, fiber.StatusMovedPermanently, resp.StatusCode)
}
// go test -race -run Test_Proxy_DoRedirects_RestoreOriginalURL
// DoRedirects with a budget of one follows the 301 to /final, and the request URL
// seen by the client is still /test.
func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) {
	t.Parallel()

	redirectAddr := createRedirectServer(t)

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		return DoRedirects(c, "http://"+redirectAddr, 1)
	})

	testCfg := fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	}
	resp, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), testCfg)
	require.NoError(t, testErr)

	payload, readErr := io.ReadAll(resp.Body)
	require.NoError(t, readErr)
	require.Equal(t, "final", string(payload))
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoRedirects_TooManyRedirects
// With a redirect budget of zero, the first 301 turns into a 500 carrying the
// too-many-redirects error message.
func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) {
	t.Parallel()

	redirectAddr := createRedirectServer(t)

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		return DoRedirects(c, "http://"+redirectAddr, 0)
	})

	testCfg := fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	}
	resp, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), testCfg)
	require.NoError(t, testErr)

	payload, readErr := io.ReadAll(resp.Body)
	require.NoError(t, readErr)
	require.Equal(t, "too many redirects detected when doing the request", string(payload))
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Equal(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoTimeout_RestoreOriginalURL
// DoTimeout succeeds within its timeout, and the request URL seen by the client
// is still /test.
func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString("proxied")
	})

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		return DoTimeout(c, "http://"+upstreamAddr, time.Second)
	})

	testCfg := fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	}
	resp, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), testCfg)
	require.NoError(t, testErr)

	payload, readErr := io.ReadAll(resp.Body)
	require.NoError(t, readErr)
	require.Equal(t, "proxied", string(payload))
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoTimeout_Timeout
// An upstream that takes 5s exceeds DoTimeout's 1s limit, producing a 500 with
// body "timeout".
func Test_Proxy_DoTimeout_Timeout(t *testing.T) {
	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		time.Sleep(5 * time.Second)
		return c.SendString("proxied")
	})

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		return DoTimeout(c, "http://"+upstreamAddr, time.Second)
	})

	testCfg := fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	}
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), testCfg)
	require.NoError(t, err)

	payload, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "timeout", string(payload))
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Equal(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoDeadline_RestoreOriginalURL
// DoDeadline succeeds before its deadline, and the request URL seen by the client
// is still /test.
func Test_Proxy_DoDeadline_RestoreOriginalURL(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString("proxied")
	})

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		deadline := time.Now().Add(time.Second)
		return DoDeadline(c, "http://"+upstreamAddr, deadline)
	})

	resp, testErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
	require.NoError(t, testErr)

	payload, readErr := io.ReadAll(resp.Body)
	require.NoError(t, readErr)
	require.Equal(t, "proxied", string(payload))
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoDeadline_PastDeadline
// The upstream sleeps 5s and DoDeadline allows 2s, but app.Test gives up after
// 1s, so app.Test itself reports a deadline-exceeded error.
func Test_Proxy_DoDeadline_PastDeadline(t *testing.T) {
	_, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		time.Sleep(time.Second * 5)
		return c.SendString("proxied")
	})
	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		return DoDeadline(c, "http://"+addr, time.Now().Add(2*time.Second))
	})
	_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{
		Timeout:       1 * time.Second,
		FailOnTimeout: true,
	})
	// ErrorIs rather than Equal: also matches if the sentinel is ever wrapped.
	require.ErrorIs(t, err1, os.ErrDeadlineExceeded)
}
// go test -race -run Test_Proxy_Do_HTTP_Prefix_URL
// A request path of the form "/http://host:port" can be turned back into an
// absolute upstream URL and proxied with Do.
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
	t.Parallel()
	_, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString("hello world")
	})
	app := fiber.New()
	app.Get("/*", func(c fiber.Ctx) error {
		url := strings.TrimPrefix(c.OriginalURL(), "/")
		// The handler may run on a goroutine other than the test's, where require
		// (t.FailNow) must not be called. Report a mismatch through the response;
		// the status assertion below catches it.
		if url != "http://"+addr {
			return fiber.NewError(fiber.StatusBadRequest, "unexpected proxy target: "+url)
		}
		if err := Do(c, url); err != nil {
			return err
		}
		c.Response().Header.Del(fiber.HeaderServer)
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/http://"+addr, http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	s, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "hello world", string(s))
}
// go test -race -run Test_Proxy_Forward_Global_Client
// Forward, with no per-call client, uses the client installed by WithClient.
//
// Deliberately NOT t.Parallel(): WithClient replaces the package-level proxy
// client, as Test_Proxy_Forward_WithClient_TLSConfig also does. Running the two
// concurrently lets one overwrite the other's client mid-test. Sequential tests
// complete before any parallel test resumes, so this ordering is deterministic.
func Test_Proxy_Forward_Global_Client(t *testing.T) {
	ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	WithClient(&fasthttp.Client{
		NoDefaultUserAgentHeader: true,
		DisablePathNormalizing:   true,
	})
	app := fiber.New()
	app.Get("/test_global_client", func(c fiber.Ctx) error {
		return c.SendString("test_global_client")
	})
	addr := ln.Addr().String()
	app.Use(Forward("http://" + addr + "/test_global_client"))
	startServer(app, ln)
	client := clientpkg.New()
	resp, err := client.Get("http://" + addr)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode())
	require.Equal(t, "test_global_client", string(resp.Body()))
	resp.Close()
}
// go test -race -run Test_Proxy_Forward_Local_Client
// Forward uses a client passed directly to it, without touching the global client.
func Test_Proxy_Forward_Local_Client(t *testing.T) {
	t.Parallel()

	ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	addr := ln.Addr().String()

	localClient := &fasthttp.Client{
		NoDefaultUserAgentHeader: true,
		DisablePathNormalizing:   true,
		Dial:                     fasthttp.Dial,
	}

	app := fiber.New()
	app.Get("/test_local_client", func(c fiber.Ctx) error {
		return c.SendString("test_local_client")
	})
	app.Use(Forward("http://"+addr+"/test_local_client", localClient))
	startServer(app, ln)

	cli := clientpkg.New()
	resp, err := cli.Get("http://" + addr)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode())
	require.Equal(t, "test_local_client", string(resp.Body()))
	resp.Close()
}
// go test -run Test_Proxy_WithClient_Nil_Panics
// WithClient rejects a nil client with a descriptive panic.
func Test_Proxy_WithClient_Nil_Panics(t *testing.T) {
	t.Parallel()

	const wantPanic = "proxy: WithClient requires a non-nil *fasthttp.Client"
	require.PanicsWithValue(t, wantPanic, func() { WithClient(nil) })
}
// go test -run Test_Proxy_Do_NilClientOverride
// Passing an explicit nil client to Do yields a 500 carrying
// errNilProxyClientOverride's message.
func Test_Proxy_Do_NilClientOverride(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString("proxied")
	})

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		var override *fasthttp.Client
		return Do(c, "http://"+upstreamAddr, override)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)

	payload, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, errNilProxyClientOverride.Error(), string(payload))
}
// go test -run Test_Proxy_Do_NonNilClientOverride
// A non-nil client passed to Do is used to reach the upstream.
func Test_Proxy_Do_NonNilClientOverride(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString("proxied")
	})

	app := fiber.New()
	app.Get("/test", func(c fiber.Ctx) error {
		override := &fasthttp.Client{
			NoDefaultUserAgentHeader: true,
			DisablePathNormalizing:   true,
		}
		return Do(c, "http://"+upstreamAddr, override)
	})

	testCfg := fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	}
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), testCfg)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	payload, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "proxied", string(payload))
}
// go test -run Test_Proxy_SelectClient_NilGlobal
// With a nil global client, selectClient returns errNilGlobalProxyClient and no client.
func Test_Proxy_SelectClient_NilGlobal(t *testing.T) {
	t.Parallel()

	got, err := selectClient(nil)
	require.Nil(t, got)
	require.ErrorIs(t, err, errNilGlobalProxyClient)
}
// go test -run Test_Proxy_NilClientOverride_AcrossHelpers
// DoRedirects, DoDeadline and DoTimeout all reject an explicit nil client override
// the same way Do does: a 500 carrying errNilProxyClientOverride's message.
func Test_Proxy_NilClientOverride_AcrossHelpers(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendString("proxied")
	})
	upstreamURL := "http://" + upstreamAddr

	cases := []struct {
		name    string
		handler fiber.Handler
	}{
		{name: "DoRedirects", handler: func(c fiber.Ctx) error {
			return DoRedirects(c, upstreamURL, 1, nil)
		}},
		{name: "DoDeadline", handler: func(c fiber.Ctx) error {
			return DoDeadline(c, upstreamURL, time.Now().Add(time.Second), nil)
		}},
		{name: "DoTimeout", handler: func(c fiber.Ctx) error {
			return DoTimeout(c, upstreamURL, time.Second, nil)
		}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Get("/test", tc.handler)

			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
			require.NoError(t, err)
			require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)

			payload, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Equal(t, errNilProxyClientOverride.Error(), string(payload))
		})
	}
}
// go test -run Test_ProxyBalancer_Custom_Client
// Balancer accepts a caller-built LBClient in place of a Servers list.
func Test_ProxyBalancer_Custom_Client(t *testing.T) {
	t.Parallel()

	upstream, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	directResp, err := upstream.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	})
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, directResp.StatusCode)

	lbc := &fasthttp.LBClient{
		Clients: []fasthttp.BalancingClient{
			&fasthttp.HostClient{
				NoDefaultUserAgentHeader: true,
				DisablePathNormalizing:   true,
				Addr:                     upstreamAddr,
			},
		},
		Timeout: time.Second,
	}

	app := fiber.New()
	app.Use(Balancer(Config{Client: lbc}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Host = upstreamAddr
	proxiedResp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, proxiedResp.StatusCode)
}
// go test -run Test_Proxy_Domain_Forward_Local
// A request addressed to "localhost:<proxy port>" matches the DomainForward
// hostname and is forwarded, with its query string, to the upstream app.
func Test_Proxy_Domain_Forward_Local(t *testing.T) {
	t.Parallel()

	proxyLn, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	upstreamLn, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)

	proxyAddr := proxyLn.Addr().String()
	upstreamAddr := upstreamLn.Addr().String()
	localDomain := strings.Replace(proxyAddr, "127.0.0.1", "localhost", 1)

	// Upstream echoes the query_test parameter.
	upstream := fiber.New()
	upstream.Get("/test", func(c fiber.Ctx) error {
		return c.SendString("test_local_client:" + c.Query("query_test"))
	})

	proxyApp := fiber.New()
	proxyApp.Use(DomainForward(localDomain, "http://"+upstreamAddr, &fasthttp.Client{
		NoDefaultUserAgentHeader: true,
		DisablePathNormalizing:   true,
		Dial:                     fasthttp.Dial,
	}))

	startServer(proxyApp, proxyLn)
	startServer(upstream, upstreamLn)

	cli := clientpkg.New()
	resp, err := cli.Get("http://" + localDomain + "/test?query_test=true")
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode())
	require.Equal(t, "test_local_client:true", string(resp.Body()))
	resp.Close()
}
// go test -run Test_Proxy_Balancer_Forward_Local
func Test_Proxy_Balancer_Forward_Local(t *testing.T) {
t.Parallel()
app := fiber.New()
_, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
return c.SendString("forwarded")
})
app.Use(BalancerForward([]string{addr}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
b, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "forwarded", string(b))
}
// go test -run Test_Proxy_Immutable
// The balancer also works when the proxy app runs with Immutable enabled.
func Test_Proxy_Immutable(t *testing.T) {
	t.Parallel()

	upstream, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	directResp, err := upstream.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{
		Timeout:       2 * time.Second,
		FailOnTimeout: true,
	})
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, directResp.StatusCode)

	app := fiber.New(fiber.Config{Immutable: true})
	app.Use(Balancer(Config{Servers: []string{upstreamAddr}}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Host = upstreamAddr
	proxiedResp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, proxiedResp.StatusCode)
}
// go test -run Test_Proxy_KeepConnectionHeader
// With KeepConnectionHeader, the upstream's Connection response header is relayed.
func Test_Proxy_KeepConnectionHeader(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		c.Set(fiber.HeaderConnection, "backend")
		return c.SendString("ok")
	})

	app := fiber.New()
	app.Use(Balancer(Config{
		Servers:              []string{upstreamAddr},
		KeepConnectionHeader: true,
	}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Host = upstreamAddr
	req.Header.Set(fiber.HeaderConnection, "keep-alive")

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, "backend", resp.Header.Get(fiber.HeaderConnection))
}
// go test -run Test_Proxy_DropConnectionHeader
// By default, the upstream's Connection response header is not relayed.
func Test_Proxy_DropConnectionHeader(t *testing.T) {
	t.Parallel()

	_, upstreamAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error {
		c.Set(fiber.HeaderConnection, "backend")
		return c.SendString("ok")
	})

	app := fiber.New()
	app.Use(Balancer(Config{Servers: []string{upstreamAddr}}))

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Host = upstreamAddr
	req.Header.Set(fiber.HeaderConnection, "keep-alive")

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Empty(t, resp.Header.Get(fiber.HeaderConnection))
}
================================================
FILE: middleware/recover/config.go
================================================
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"github.com/gofiber/fiber/v3"
)
// Config defines the config for the recover middleware.
type Config struct {
	// Next defines a function to skip this middleware when it returns true.
	// When skipped, panics in later handlers are not recovered by this instance.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool
	// PanicHandler converts the recovered panic value r into the error that is
	// returned from the handler chain (and thus reaches the app's ErrorHandler).
	//
	// Optional. Default: DefaultPanicHandler
	PanicHandler func(c fiber.Ctx, r any) error
	// StackTraceHandler is called with the recovered panic value, but only when
	// EnableStackTrace is true. If EnableStackTrace is set and this is nil,
	// configDefault fills in defaultStackTraceHandler, which writes the panic and
	// stack to stderr.
	//
	// Optional. Default: defaultStackTraceHandler
	StackTraceHandler func(c fiber.Ctx, e any)
	// EnableStackTrace enables calling StackTraceHandler on a recovered panic.
	//
	// Optional. Default: false
	EnableStackTrace bool
}
// ConfigDefault is the config used when New is called without arguments:
// stack traces disabled, panics converted by DefaultPanicHandler.
var ConfigDefault = Config{
	PanicHandler:      DefaultPanicHandler,
	StackTraceHandler: defaultStackTraceHandler,
	Next:              nil,
	EnableStackTrace:  false,
}

// configDefault returns ConfigDefault when no config is given. Otherwise it
// returns the first config with missing handlers filled in: PanicHandler always,
// and StackTraceHandler only when stack traces are enabled.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}

	cfg := config[0]
	if cfg.PanicHandler == nil {
		cfg.PanicHandler = DefaultPanicHandler
	}
	if cfg.StackTraceHandler == nil && cfg.EnableStackTrace {
		cfg.StackTraceHandler = defaultStackTraceHandler
	}
	return cfg
}
================================================
FILE: middleware/recover/recover.go
================================================
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"fmt"
"os"
"runtime/debug"
"github.com/gofiber/fiber/v3"
)
// defaultStackTraceHandler writes the panic value followed by the current
// goroutine's stack trace to stderr. Write errors are ignored.
func defaultStackTraceHandler(_ fiber.Ctx, e any) {
	stack := debug.Stack()
	_, _ = fmt.Fprintf(os.Stderr, "panic: %v\n\n%s\n", e, stack)
}
// DefaultPanicHandler turns a recovered panic value into an error. An error value
// is returned unchanged; anything else is formatted with %v into a new error.
func DefaultPanicHandler(_ fiber.Ctx, r any) error {
	switch v := r.(type) {
	case error:
		return v
	default:
		return fmt.Errorf("%v", v)
	}
}
// New creates a middleware that recovers from panics raised by later handlers.
// The panic value is optionally passed to StackTraceHandler (when
// EnableStackTrace is set), then converted by PanicHandler into the error this
// handler returns, which the global error handler receives.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)

	return func(c fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// recover() must be called directly in this deferred func to take effect.
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			if cfg.EnableStackTrace {
				cfg.StackTraceHandler(c, r)
			}
			err = cfg.PanicHandler(c, r)
		}()

		return c.Next()
	}
}
================================================
FILE: middleware/recover/recover_test.go
================================================
package recover //nolint:predeclared // TODO: Rename to some non-builtin
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
// go test -run Test_Recover
// Panics are converted into errors by DefaultPanicHandler or a custom
// PanicHandler and handed to the app's ErrorHandler with the expected message.
func Test_Recover(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name         string
		panicVal     any
		panicHandler func(c fiber.Ctx, r any) error
		errorMsg     string
	}{
		{
			name:         "non-error panic will be handled by default",
			panicVal:     "Hi, I'm an error!",
			panicHandler: nil,
			errorMsg:     "Hi, I'm an error!",
		},
		{
			name:         "error panic will be handled by default",
			panicVal:     errors.New("hi, I'm an error object"),
			panicHandler: nil,
			errorMsg:     "hi, I'm an error object",
		},
		{
			name:     "non-error panic will be handled",
			panicVal: "Hi, I'm an error!",
			panicHandler: func(c fiber.Ctx, r any) error {
				return fmt.Errorf("[RECOVERED]: %w", DefaultPanicHandler(c, r))
			},
			errorMsg: "[RECOVERED]: Hi, I'm an error!",
		},
		{
			name:     "error panic will be handled",
			panicVal: errors.New("hi, I'm an error object"),
			panicHandler: func(c fiber.Ctx, r any) error {
				return fmt.Errorf("[RECOVERED]: %w", DefaultPanicHandler(c, r))
			},
			errorMsg: "[RECOVERED]: hi, I'm an error object",
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New(fiber.Config{
				// The ErrorHandler may run on a goroutine other than the test's,
				// where require (t.FailNow) must not be called. Signal a message
				// mismatch through the status code instead and assert on it below.
				ErrorHandler: func(c fiber.Ctx, err error) error {
					if err == nil || err.Error() != tc.errorMsg {
						return c.SendStatus(fiber.StatusInternalServerError)
					}
					return c.SendStatus(fiber.StatusTeapot)
				},
			})
			app.Use(New(Config{PanicHandler: tc.panicHandler}))
			app.Get("/panic", func(_ fiber.Ctx) error {
				panic(tc.panicVal)
			})
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", http.NoBody))
			require.NoError(t, err)
			require.Equal(t, fiber.StatusTeapot, resp.StatusCode, "recovered error message did not match %q", tc.errorMsg)
		})
	}
}
// go test -run Test_Recover_Next
// Test_Recover_Next verifies that the middleware is bypassed when Next reports
// true; with no routes registered the request ends in a 404.
func Test_Recover_Next(t *testing.T) {
	t.Parallel()
	skipAlways := func(_ fiber.Ctx) bool { return true }
	app := fiber.New()
	app.Use(New(Config{Next: skipAlways}))
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// Test_Recover_EnableStackTrace checks that a panic is still turned into a
// 500 response when stack traces are enabled.
func Test_Recover_EnableStackTrace(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{EnableStackTrace: true}))
	app.Get("/panic", func(_ fiber.Ctx) error {
		panic("Hi, I'm an error!")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/panic", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
}
================================================
FILE: middleware/redirect/config.go
================================================
package redirect
import (
"regexp"
"github.com/gofiber/fiber/v3"
)
// Config defines the config for middleware.
type Config struct {
	// Next defines a function to skip this middleware when it returns true.
	// Optional. Default: nil
	Next func(fiber.Ctx) bool
	// Rules maps a source path pattern to a redirect target. Each asterisk in a
	// pattern is a wildcard whose captured value can be referenced in the
	// target by position, e.g. $1, $2 and so on. Patterns are compiled as
	// regular expressions anchored at the end only.
	// Required for the middleware to have any effect. Example:
	// "/old": "/new",
	// "/api/*": "/$1",
	// "/js/*": "/public/javascript/$1",
	// "/users/*/orders/*": "/user/$1/order/$2",
	Rules map[string]string
	// rulesRegex holds the compiled form of Rules; it is populated by New.
	rulesRegex map[*regexp.Regexp]string
	// StatusCode is the HTTP status code sent with the redirect.
	// Optional. Default: 302 Found
	StatusCode int
}
// ConfigDefault is the configuration used when New is called without
// arguments: no rules, redirecting with 302 Found.
var ConfigDefault = Config{StatusCode: fiber.StatusFound}
// configDefault returns ConfigDefault when no Config is supplied; otherwise it
// returns the first Config with a zero StatusCode replaced by the default.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}
	cfg := config[0]
	if cfg.StatusCode != 0 {
		return cfg
	}
	cfg.StatusCode = ConfigDefault.StatusCode
	return cfg
}
================================================
FILE: middleware/redirect/redirect.go
================================================
package redirect
import (
	"regexp"
	"slices"
	"strconv"
	"strings"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/utils/v2"
)
// New creates a redirect middleware. The request path is matched against the
// configured Rules; on the first match the client is redirected to the rule's
// target ("$N" placeholders replaced by the wildcard captures, the original
// query string appended) with cfg.StatusCode. Requests that match no rule, or
// for which Next returns true, continue down the chain.
//
// When several rules match the same path, longer patterns are tried first and
// equal-length patterns in lexical order. The outcome therefore no longer
// depends on Go's randomized map iteration order.
//
// New panics if a rule does not compile as a regular expression.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)
	// Sort the patterns once so rule precedence is deterministic.
	patterns := make([]string, 0, len(cfg.Rules))
	for k := range cfg.Rules {
		patterns = append(patterns, k)
	}
	slices.SortFunc(patterns, func(a, b string) int {
		if len(a) != len(b) {
			return len(b) - len(a) // longer (more specific) patterns first
		}
		return strings.Compare(a, b)
	})
	// Each "*" becomes a capture group. The pattern is anchored at the end
	// only, so "/old" also matches "/x/old".
	cfg.rulesRegex = make(map[*regexp.Regexp]string, len(patterns))
	ordered := make([]*regexp.Regexp, 0, len(patterns))
	for _, k := range patterns {
		re := regexp.MustCompile(strings.ReplaceAll(k, "*", "(.*)") + "$")
		cfg.rulesRegex[re] = cfg.Rules[k]
		ordered = append(ordered, re)
	}
	return func(c fiber.Ctx) error {
		// Next request to skip middleware
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}
		path := c.Path()
		for _, re := range ordered {
			replacer := captureTokens(re, path)
			if replacer == nil {
				continue
			}
			// The concatenations below copy the query bytes, so the unsafe view
			// is not retained past this call.
			queryString := utils.UnsafeString(c.RequestCtx().QueryArgs().QueryString())
			if queryString != "" {
				queryString = "?" + queryString
			}
			return c.Redirect().Status(cfg.StatusCode).To(replacer.Replace(cfg.rulesRegex[re]) + queryString)
		}
		return c.Next()
	}
}
// captureTokens matches input against pattern. On success it returns a
// Replacer that substitutes "$1", "$2", ... with the corresponding capture
// groups; it returns nil when pattern does not match. Inputs longer than one
// byte have trailing slashes stripped before matching, so "/" stays "/".
//
// Placeholders are registered from the highest index down. strings.Replacer
// compares old strings in argument order, so registering "$1" before "$10"
// would expand "$10" to the first capture followed by a literal "0".
//
// Adapted from https://github.com/labstack/echo/blob/master/middleware/rewrite.go
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
	if len(input) > 1 {
		input = strings.TrimRight(input, "/")
	}
	match := pattern.FindStringSubmatch(input)
	if match == nil {
		return nil
	}
	values := match[1:]
	pairs := make([]string, 0, 2*len(values))
	for i := len(values); i > 0; i-- {
		pairs = append(pairs, "$"+strconv.Itoa(i), values[i-1])
	}
	return strings.NewReplacer(pairs...)
}
================================================
FILE: middleware/redirect/redirect_test.go
================================================
package redirect
import (
"context"
"net/http"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
// Test_Redirect installs one redirect middleware per rule and checks the status
// code and Location header for exact, wildcard, root and query-string cases,
// plus paths that match no rule.
func Test_Redirect(t *testing.T) {
	// Keep the *App returned by fiber.New. Dereferencing it would copy a
	// struct that holds locks and internal state (go vet: copylocks).
	app := fiber.New()
	rules := []struct {
		from   string
		to     string
		status int
	}{
		{from: "/default", to: "google.com", status: fiber.StatusMovedPermanently},
		{from: "/default/*", to: "fiber.wiki", status: fiber.StatusTemporaryRedirect},
		{from: "/redirect/*", to: "$1", status: fiber.StatusSeeOther},
		{from: "/pattern/*", to: "golang.org", status: fiber.StatusFound},
		{from: "/", to: "/swagger", status: fiber.StatusMovedPermanently},
		{from: "/params", to: "/with_params", status: fiber.StatusMovedPermanently},
	}
	for _, r := range rules {
		app.Use(New(Config{
			Rules:      map[string]string{r.from: r.to},
			StatusCode: r.status,
		}))
	}
	app.Get("/api/*", func(c fiber.Ctx) error {
		return c.SendString("API")
	})
	app.Get("/new", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})
	tests := []struct {
		name       string
		url        string
		redirectTo string
		statusCode int
	}{
		{
			name:       "should be returns status StatusFound without a wildcard",
			url:        "/default",
			redirectTo: "google.com",
			statusCode: fiber.StatusMovedPermanently,
		},
		{
			name:       "should be returns status StatusTemporaryRedirect using wildcard",
			url:        "/default/xyz",
			redirectTo: "fiber.wiki",
			statusCode: fiber.StatusTemporaryRedirect,
		},
		{
			name:       "should be returns status StatusSeeOther without set redirectTo to use the default",
			url:        "/redirect/github.com/gofiber/redirect",
			redirectTo: "github.com/gofiber/redirect",
			statusCode: fiber.StatusSeeOther,
		},
		{
			name:       "should return the status code default",
			url:        "/pattern/xyz",
			redirectTo: "golang.org",
			statusCode: fiber.StatusFound,
		},
		{
			name:       "access URL without rule",
			url:        "/new",
			statusCode: fiber.StatusOK,
		},
		{
			name:       "redirect to swagger route",
			url:        "/",
			redirectTo: "/swagger",
			statusCode: fiber.StatusMovedPermanently,
		},
		{
			name:       "no redirect to swagger route",
			url:        "/api/",
			statusCode: fiber.StatusOK,
		},
		{
			name:       "no redirect to swagger route #2",
			url:        "/api/test",
			statusCode: fiber.StatusOK,
		},
		{
			name:       "redirect with query params",
			url:        "/params?query=abc",
			redirectTo: "/with_params?query=abc",
			statusCode: fiber.StatusMovedPermanently,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, tt.url, http.NoBody)
			require.NoError(t, err)
			resp, err := app.Test(req)
			require.NoError(t, err)
			require.Equal(t, tt.statusCode, resp.StatusCode)
			require.Equal(t, tt.redirectTo, resp.Header.Get("Location"))
		})
	}
}
// Test_Next checks that Next returning true bypasses the redirect rules and
// that Next returning false lets them apply.
func Test_Next(t *testing.T) {
	// Case 1 : Next function always returns true
	// Use the *App directly; dereferencing it copies internal locks.
	app := fiber.New()
	app.Use(New(Config{
		Next: func(fiber.Ctx) bool {
			return true
		},
		Rules: map[string]string{
			"/default": "google.com",
		},
		StatusCode: fiber.StatusMovedPermanently,
	}))
	app.Use(func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody)
	require.NoError(t, err)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	// Case 2 : Next function always returns false
	app = fiber.New()
	app.Use(New(Config{
		Next: func(fiber.Ctx) bool {
			return false
		},
		Rules: map[string]string{
			"/default": "google.com",
		},
		StatusCode: fiber.StatusMovedPermanently,
	}))
	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody)
	require.NoError(t, err)
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusMovedPermanently, resp.StatusCode)
	require.Equal(t, "google.com", resp.Header.Get("Location"))
}
// Test_NoRules checks that a redirect middleware without rules only passes
// requests on: to the fallback handler when present, otherwise to a 404.
func Test_NoRules(t *testing.T) {
	// Case 1: No rules with default route defined
	// Use the *App directly; dereferencing it copies internal locks.
	app := fiber.New()
	app.Use(New(Config{
		StatusCode: fiber.StatusMovedPermanently,
	}))
	app.Use(func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody)
	require.NoError(t, err)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	// Case 2: No rules and no default route defined
	app = fiber.New()
	app.Use(New(Config{
		StatusCode: fiber.StatusMovedPermanently,
	}))
	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody)
	require.NoError(t, err)
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// Test_DefaultConfig checks that New() with no arguments does not redirect.
func Test_DefaultConfig(t *testing.T) {
	// Case 1: Default config and no default route
	// Use the *App directly; dereferencing it copies internal locks.
	app := fiber.New()
	app.Use(New())
	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody)
	require.NoError(t, err)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
	// Case 2: Default config and default route
	app = fiber.New()
	app.Use(New())
	app.Use(func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody)
	require.NoError(t, err)
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
// Test_RegexRules covers an empty rule set, a valid rule, and an invalid
// pattern, which must make New panic.
func Test_RegexRules(t *testing.T) {
	// Case 1: Rules regex is empty
	// Use the *App directly; dereferencing it copies internal locks.
	app := fiber.New()
	app.Use(New(Config{
		Rules:      map[string]string{},
		StatusCode: fiber.StatusMovedPermanently,
	}))
	app.Use(func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody)
	require.NoError(t, err)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	// Case 2: Rules regex map contains valid regex and well-formed replacement URLs
	app = fiber.New()
	app.Use(New(Config{
		Rules: map[string]string{
			"/default": "google.com",
		},
		StatusCode: fiber.StatusMovedPermanently,
	}))
	app.Use(func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody)
	require.NoError(t, err)
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusMovedPermanently, resp.StatusCode)
	require.Equal(t, "google.com", resp.Header.Get("Location"))
	// Case 3: Test invalid regex throws panic
	app = fiber.New()
	require.Panics(t, func() {
		app.Use(New(Config{
			Rules: map[string]string{
				"(": "google.com",
			},
			StatusCode: fiber.StatusMovedPermanently,
		}))
	})
}
================================================
FILE: middleware/requestid/config.go
================================================
package requestid
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
)
// Config defines the config for the requestid middleware.
type Config struct {
	// Next defines a function to skip this middleware when it returns true.
	// A skipped request gets no request ID header and no stored ID.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool
	// Generator defines a function to generate the unique identifier.
	// It is consulted when the incoming header value is missing or invalid
	// (empty, or containing bytes outside 0x20-0x7E). If it yields three
	// invalid values in a row, the middleware falls back to utils.SecureToken.
	//
	// Optional. Default: utils.SecureToken
	Generator func() string
	// Header is the header key used both to read an incoming request ID and
	// to write the chosen ID to the response.
	//
	// Optional. Default: "X-Request-ID"
	Header string
}
// ConfigDefault is applied when New is called without a Config. IDs are taken
// from, and written to, fiber.HeaderXRequestID and are generated with the
// secure token generator utils.SecureToken.
var ConfigDefault = Config{
	Generator: utils.SecureToken,
	Header:    fiber.HeaderXRequestID,
}
// configDefault returns ConfigDefault when no Config is supplied; otherwise it
// fills the empty Header and nil Generator of the first Config from
// ConfigDefault.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}
	cfg := config[0]
	if cfg.Generator == nil {
		cfg.Generator = ConfigDefault.Generator
	}
	if len(cfg.Header) == 0 {
		cfg.Header = ConfigDefault.Header
	}
	return cfg
}
================================================
FILE: middleware/requestid/requestid.go
================================================
package requestid
import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
)
// contextKey is unexported so that no other package can construct a key that
// collides with the context keys declared here.
type contextKey int

// requestIDKey is the key under which New stores the request ID.
const requestIDKey contextKey = 0
// New creates a request ID middleware. For each request it keeps the ID from
// cfg.Header when isValidRequestID accepts it; otherwise it obtains one from
// cfg.Generator (see sanitizeRequestID). It then writes the ID to the response
// header and stores it so FromContext can retrieve it.
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)
	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}
		// Fiber documents c.Get's result as valid only within the handler.
		// The ID is stored in locals and possibly a context.Context that can
		// outlive the request, so detach it from the request buffer first.
		rid := sanitizeRequestID(utils.CopyString(c.Get(cfg.Header)), cfg.Generator)
		// Set new id to response header
		c.Set(cfg.Header, rid)
		// Add the request ID to locals
		fiber.StoreInContext(c, requestIDKey, rid)
		// Continue stack
		return c.Next()
	}
}
// sanitizeRequestID returns rid unchanged when isValidRequestID accepts it.
// Otherwise it asks generator for up to three replacement values and returns
// the first valid one. If all three are rejected, it falls back to
// utils.SecureToken, which is expected to always pass isValidRequestID.
func sanitizeRequestID(rid string, generator func() string) string {
	const maxGeneratorAttempts = 3
	candidate := rid
	for attempt := 0; ; attempt++ {
		if isValidRequestID(candidate) {
			return candidate
		}
		if attempt == maxGeneratorAttempts {
			break
		}
		candidate = generator()
	}
	return utils.SecureToken()
}
// isValidRequestID reports whether rid is non-empty and consists solely of
// printable ASCII bytes, from space (0x20) through tilde (0x7E). Control
// characters such as CR/LF, DEL and any non-ASCII byte are rejected.
func isValidRequestID(rid string) bool {
	if len(rid) == 0 {
		return false
	}
	for _, b := range []byte(rid) {
		if b < ' ' || b > '~' {
			return false
		}
	}
	return true
}
// FromContext returns the request ID that New stored for the current request.
// ctx may be a fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx or
// context.Context. An empty string is returned when no ID is present.
func FromContext(ctx any) string {
	rid, found := fiber.ValueFromContext[string](ctx, requestIDKey)
	if !found {
		return ""
	}
	return rid
}
================================================
FILE: middleware/requestid/requestid_test.go
================================================
package requestid
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
// go test -run Test_RequestID
// Test_RequestID checks that a 43-character ID is generated when the client
// sends none, and that a client-supplied ID is echoed back unchanged.
func Test_RequestID(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World 👋!")
	})
	first, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, first.StatusCode)
	generated := first.Header.Get(fiber.HeaderXRequestID)
	require.Len(t, generated, 43)
	echoReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	echoReq.Header.Add(fiber.HeaderXRequestID, generated)
	second, err := app.Test(echoReq)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, second.StatusCode)
	require.Equal(t, generated, second.Header.Get(fiber.HeaderXRequestID))
}
// Test_RequestID_InvalidHeaderValue checks that an incoming value containing
// CR/LF is replaced by the generator's output.
func Test_RequestID_InvalidHeaderValue(t *testing.T) {
	t.Parallel()
	const generated = "clean-generated-id"
	got := sanitizeRequestID("bad\r\nid", func() string { return generated })
	require.Equal(t, generated, got)
}
// Test_RequestID_InvalidGeneratedValue checks that a generator producing CR/LF
// values is ignored and a 43-character SecureToken is used instead.
func Test_RequestID_InvalidGeneratedValue(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Generator: func() string { return "bad\r\nid" },
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	rid := resp.Header.Get(fiber.HeaderXRequestID)
	require.NotEmpty(t, rid)
	for _, forbidden := range []string{"\r", "\n"} {
		require.NotContains(t, rid, forbidden)
	}
	require.Len(t, rid, 43, "Fallback should produce a SecureToken")
}
// Test_RequestID_GeneratorAlwaysInvalid checks the final SecureToken fallback
// when every generated value contains a NUL byte.
func Test_RequestID_GeneratorAlwaysInvalid(t *testing.T) {
	t.Parallel()
	alwaysInvalid := func() string {
		return "invalid\x00id" // Always invalid due to null byte
	}
	app := fiber.New()
	app.Use(New(Config{Generator: alwaysInvalid}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	rid := resp.Header.Get(fiber.HeaderXRequestID)
	require.NotEmpty(t, rid)
	require.Len(t, rid, 43, "Should fall back to SecureToken after 3 invalid attempts")
}
// Test_RequestID_CustomGenerator checks that a valid custom generator value is
// used verbatim as the response request ID.
func Test_RequestID_CustomGenerator(t *testing.T) {
	t.Parallel()
	const want = "custom-valid-id"
	app := fiber.New()
	app.Use(New(Config{
		Generator: func() string { return want },
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, want, resp.Header.Get(fiber.HeaderXRequestID))
}
func Test_isValidRequestID_VisibleASCII(t *testing.T) {
t.Parallel()
require.True(t, isValidRequestID("request-id-09AZaz ~"))
}
// Test_isValidRequestID_Boundaries probes the edges of the accepted byte range
// and the empty string.
func Test_isValidRequestID_Boundaries(t *testing.T) {
	t.Parallel()
	t.Run("allows space and tilde", func(t *testing.T) {
		t.Parallel()
		require.True(t, isValidRequestID(" ~"))
	})
	t.Run("rejects out of range", func(t *testing.T) {
		t.Parallel()
		for _, b := range []byte{0x1f, 0x7f} {
			require.False(t, isValidRequestID(string([]byte{b})))
		}
	})
	t.Run("rejects empty", func(t *testing.T) {
		t.Parallel()
		require.False(t, isValidRequestID(""))
	})
}
func Test_isValidRequestID_RejectsObsText(t *testing.T) {
t.Parallel()
require.False(t, isValidRequestID("valid\xff"))
}
// go test -run Test_RequestID_Next
// Test_RequestID_Next checks that a skipped request gets no request ID header.
func Test_RequestID_Next(t *testing.T) {
	t.Parallel()
	skip := func(_ fiber.Ctx) bool { return true }
	app := fiber.New()
	app.Use(New(Config{Next: skip}))
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Empty(t, resp.Header.Get(fiber.HeaderXRequestID))
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}
// go test -run Test_RequestID_FromContext
// Test_RequestID_FromContext checks that a later handler can read the stored
// request ID through FromContext.
func Test_RequestID_FromContext(t *testing.T) {
	t.Parallel()
	reqID := "ThisIsARequestId"
	app := fiber.New()
	app.Use(New(Config{
		Generator: func() string { return reqID },
	}))
	// Captured in the handler and asserted once app.Test has returned.
	var seen string
	app.Use(func(c fiber.Ctx) error {
		seen = FromContext(c)
		return c.Next()
	})
	_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, reqID, seen)
}
// Test_RequestID_FromContext_Empty checks that FromContext returns "" when the
// requestid middleware is not installed.
func Test_RequestID_FromContext_Empty(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	// No middleware. The handler runs on the goroutine app.Test starts, where
	// require must not fail the test, so it only records the value. The
	// sentinel also proves the handler actually ran.
	ctxVal := "handler-not-called"
	app.Use(func(c fiber.Ctx) error {
		ctxVal = FromContext(c)
		return c.SendStatus(fiber.StatusOK)
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Empty(t, ctxVal)
}
// Test_RequestID_FromContext_Types checks that FromContext finds the ID from
// every supported context type when PassLocalsToContext is enabled.
func Test_RequestID_FromContext_Types(t *testing.T) {
	t.Parallel()
	reqID := "request-id-123"
	app := fiber.New(fiber.Config{PassLocalsToContext: true})
	app.Use(New(Config{
		Generator: func() string {
			return reqID
		},
	}))
	// The handler runs on the goroutine app.Test starts, where require's
	// FailNow must not be called; record the results and assert afterwards.
	var fromCtx, fromCustom, fromFastHTTP, fromStd string
	isCustomCtx := false
	app.Get("/", func(c fiber.Ctx) error {
		fromCtx = FromContext(c)
		if customCtx, ok := c.(fiber.CustomCtx); ok {
			isCustomCtx = true
			fromCustom = FromContext(customCtx)
		}
		fromFastHTTP = FromContext(c.RequestCtx())
		fromStd = FromContext(c.Context())
		return c.SendStatus(fiber.StatusOK)
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, reqID, fromCtx, "fiber.Ctx")
	require.True(t, isCustomCtx, "fiber.Ctx should implement fiber.CustomCtx")
	require.Equal(t, reqID, fromCustom, "fiber.CustomCtx")
	require.Equal(t, reqID, fromFastHTTP, "*fasthttp.RequestCtx")
	require.Equal(t, reqID, fromStd, "context.Context")
}
================================================
FILE: middleware/responsetime/config.go
================================================
package responsetime
import (
"github.com/gofiber/fiber/v3"
)
// Config defines the config for the responsetime middleware.
type Config struct {
	// Next defines a function to skip this middleware when it returns true.
	// A skipped request gets no response-time header.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool
	// Header is the header key used to set the response time. The value is
	// the time spent in the downstream handlers, formatted with
	// time.Duration's String method.
	//
	// Optional. Default: "X-Response-Time"
	Header string
}
// ConfigDefault is used when New receives no Config; the measurement is
// written to fiber.HeaderXResponseTime.
var ConfigDefault = Config{
	Header: fiber.HeaderXResponseTime,
}
// configDefault returns ConfigDefault when called without arguments; otherwise
// it returns the first Config with an empty Header replaced by the default.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}
	cfg := config[0]
	if len(cfg.Header) == 0 {
		cfg.Header = ConfigDefault.Header
	}
	return cfg
}
================================================
FILE: middleware/responsetime/responsetime.go
================================================
package responsetime
import (
"time"
"github.com/gofiber/fiber/v3"
)
// New creates a middleware that measures how long the rest of the handler
// chain takes and reports it in cfg.Header, formatted with time.Duration's
// String method. The header is set even when a downstream handler returns an
// error, and that error is returned unchanged.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)
	return func(c fiber.Ctx) error {
		if next := cfg.Next; next != nil && next(c) {
			return c.Next()
		}
		began := time.Now()
		chainErr := c.Next()
		elapsed := time.Since(began)
		c.Set(cfg.Header, elapsed.String())
		return chainErr
	}
}
================================================
FILE: middleware/responsetime/responsetime_test.go
================================================
package responsetime
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
)
// TestResponseTimeMiddleware checks that the header is set with a parseable
// duration, is omitted when Next skips the middleware, and is still set when
// the handler's error is propagated to the app's ErrorHandler.
func TestResponseTimeMiddleware(t *testing.T) {
	t.Parallel()
	boom := errors.New("boom")
	tests := []struct {
		name                  string
		expectedStatus        int
		useCustomErrorHandler bool
		returnError           bool
		expectHeader          bool
		skipWithNext          bool
	}{
		{
			name:           "sets duration header",
			expectedStatus: fiber.StatusOK,
			expectHeader:   true,
		},
		{
			name:           "skips when Next returns true",
			expectedStatus: fiber.StatusOK,
			expectHeader:   false,
			skipWithNext:   true,
		},
		{
			name:                  "propagates errors",
			expectedStatus:        fiber.StatusTeapot,
			useCustomErrorHandler: true,
			returnError:           true,
			expectHeader:          true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			configs := []Config(nil)
			if tt.skipWithNext {
				configs = []Config{{
					Next: func(fiber.Ctx) bool {
						return true
					},
				}}
			}
			// The ErrorHandler runs on the goroutine app.Test starts, so it only
			// records the error; the assertion happens on the test goroutine.
			var handledErr error
			appConfig := fiber.Config{}
			if tt.useCustomErrorHandler {
				appConfig.ErrorHandler = func(c fiber.Ctx, err error) error {
					handledErr = err
					return c.Status(fiber.StatusTeapot).SendString(err.Error())
				}
			}
			app := fiber.New(appConfig)
			app.Use(New(configs...))
			app.Get("/", func(c fiber.Ctx) error {
				if tt.returnError {
					return boom
				}
				return c.SendStatus(fiber.StatusOK)
			})
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			require.NoError(t, err)
			require.Equal(t, tt.expectedStatus, resp.StatusCode)
			if tt.useCustomErrorHandler {
				require.ErrorIs(t, handledErr, boom)
			}
			header := resp.Header.Get(fiber.HeaderXResponseTime)
			if tt.expectHeader {
				require.NotEmpty(t, header)
				_, parseErr := time.ParseDuration(header)
				require.NoError(t, parseErr)
				return
			}
			require.Empty(t, header)
		})
	}
}
================================================
FILE: middleware/rewrite/config.go
================================================
package rewrite
import (
"regexp"
"github.com/gofiber/fiber/v3"
)
// Config defines the config for the rewrite middleware.
type Config struct {
	// Next defines a function to skip middleware. When it returns true the
	// request continues with its path unchanged.
	// Optional. Default: nil
	Next func(fiber.Ctx) bool
	// Rules defines the URL path rewrite rules. The values captured in asterisk can be
	// retrieved by index e.g. $1, $2 and so on. New compiles each key into a
	// regular expression by replacing "*" with "(.*)" and appending "$", so a
	// pattern is anchored at the end but not at the start.
	// Required for the middleware to have any effect; with no rules every
	// path passes through unchanged. Example:
	// "/old": "/new",
	// "/api/*": "/$1",
	// "/js/*": "/public/javascript/$1",
	// "/users/*/orders/*": "/user/$1/order/$2",
	Rules map[string]string
	// rulesRegex holds the compiled Rules; it is populated by New.
	rulesRegex map[*regexp.Regexp]string
}
// configDefault returns a copy of the first supplied Config, or the zero Config
// when none is given; this middleware has no default values to fill in.
func configDefault(config ...Config) Config {
	var cfg Config
	if len(config) > 0 {
		cfg = config[0]
	}
	return cfg
}
================================================
FILE: middleware/rewrite/rewrite.go
================================================
package rewrite
import (
"regexp"
"strconv"
"strings"
"github.com/gofiber/fiber/v3"
)
// New creates a path-rewrite middleware. The request path is matched against
// each rule; on the first match the path is replaced with the rule's target
// (with "$N" placeholders expanded) and the chain continues with the rewritten
// path. Requests matching no rule, or skipped via Next, pass through untouched.
//
// NOTE(review): the rules are kept in a map and Go randomizes map iteration,
// so when several rules match the same path, which one is applied is unspecified.
func New(config ...Config) fiber.Handler {
	cfg := configDefault(config...)
	// Compile each rule once: "*" becomes a capture group and the pattern is
	// anchored at the end only. An invalid pattern makes MustCompile panic.
	cfg.rulesRegex = make(map[*regexp.Regexp]string)
	for pattern, target := range cfg.Rules {
		expr := strings.ReplaceAll(pattern, "*", "(.*)") + "$"
		cfg.rulesRegex[regexp.MustCompile(expr)] = target
	}
	return func(c fiber.Ctx) error {
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}
		path := c.Path()
		for re, target := range cfg.rulesRegex {
			if replacer := captureTokens(re, path); replacer != nil {
				c.Path(replacer.Replace(target))
				break
			}
		}
		return c.Next()
	}
}
// captureTokens matches input against pattern. On success it returns a
// Replacer that substitutes "$1", "$2", ... with the corresponding capture
// groups; it returns nil when pattern does not match.
//
// Placeholders are registered from the highest index down. strings.Replacer
// compares old strings in argument order, so registering "$1" before "$10"
// would expand "$10" to the first capture followed by a literal "0".
//
// Adapted from https://github.com/labstack/echo/blob/master/middleware/rewrite.go
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
	match := pattern.FindStringSubmatch(input)
	if match == nil {
		return nil
	}
	values := match[1:]
	pairs := make([]string, 0, 2*len(values))
	for i := len(values); i > 0; i-- {
		pairs = append(pairs, "$"+strconv.Itoa(i), values[i-1])
	}
	return strings.NewReplacer(pairs...)
}
================================================
FILE: middleware/rewrite/rewrite_test.go
================================================
package rewrite
import (
"context"
"fmt"
"io"
"net/http"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func Test_New(t *testing.T) {
// Test with no config
m := New()
if m == nil {
t.Error("Expected middleware to be returned, got nil")
}
// Test with config
m = New(Config{
Rules: map[string]string{
"/old": "/new",
},
})
if m == nil {
t.Error("Expected middleware to be returned, got nil")
}
// Test with full config
m = New(Config{
Next: func(fiber.Ctx) bool {
return true
},
Rules: map[string]string{
"/old": "/new",
},
})
if m == nil {
t.Error("Expected middleware to be returned, got nil")
}
}
// Test_Rewrite exercises the rewrite middleware end to end: skipping via Next,
// a plain rewrite, wildcard captures, and non-matching paths with and without
// a fallback handler.
func Test_Rewrite(t *testing.T) {
	// get sends a GET request for target and returns the status code and body.
	get := func(app *fiber.App, target string) (int, string) {
		t.Helper()
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, target, http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.NoError(t, resp.Body.Close())
		return resp.StatusCode, string(body)
	}
	orderHandler := func(c fiber.Ctx) error {
		return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID")))
	}
	// Case 1: Next function always returns true, so "/old" is served unchanged
	app := fiber.New()
	app.Use(New(Config{
		Next: func(fiber.Ctx) bool {
			return true
		},
		Rules: map[string]string{
			"/old": "/new",
		},
	}))
	app.Get("/old", func(c fiber.Ctx) error {
		return c.SendString("Rewrite Successful")
	})
	status, body := get(app, "/old")
	require.Equal(t, fiber.StatusOK, status)
	require.Equal(t, "Rewrite Successful", body)
	// Case 2: Next function always returns false, so "/old" is rewritten to "/new"
	app = fiber.New()
	app.Use(New(Config{
		Next: func(fiber.Ctx) bool {
			return false
		},
		Rules: map[string]string{
			"/old": "/new",
		},
	}))
	app.Get("/new", func(c fiber.Ctx) error {
		return c.SendString("Rewrite Successful")
	})
	status, body = get(app, "/old")
	require.Equal(t, fiber.StatusOK, status)
	require.Equal(t, "Rewrite Successful", body)
	// Case 3: check for captured tokens in rewrite rule
	app = fiber.New()
	app.Use(New(Config{
		Rules: map[string]string{
			"/users/*/orders/*": "/user/$1/order/$2",
		},
	}))
	app.Get("/user/:userID/order/:orderID", orderHandler)
	status, body = get(app, "/users/123/orders/456")
	require.Equal(t, fiber.StatusOK, status)
	require.Equal(t, "User ID: 123, Order ID: 456", body)
	// Case 4: Send non-matching request, handled by default route
	app = fiber.New()
	app.Use(New(Config{
		Rules: map[string]string{
			"/users/*/orders/*": "/user/$1/order/$2",
		},
	}))
	app.Get("/user/:userID/order/:orderID", orderHandler)
	app.Use(func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	status, body = get(app, "/not-matching-any-rule")
	require.Equal(t, fiber.StatusOK, status)
	require.Equal(t, "OK", body)
	// Case 5: Send non-matching request, with no default route
	app = fiber.New()
	app.Use(New(Config{
		Rules: map[string]string{
			"/users/*/orders/*": "/user/$1/order/$2",
		},
	}))
	app.Get("/user/:userID/order/:orderID", orderHandler)
	status, _ = get(app, "/not-matching-any-rule")
	require.Equal(t, fiber.StatusNotFound, status)
}
// Benchmark_Rewrite measures the rewrite middleware for skipped, matching and
// non-matching requests.
//
// app.Handler() is resolved once, outside the timed loop, so any setup it
// performs per call is not measured. The request URI is reset on every
// iteration because a successful rewrite updates the request path held by
// reqCtx (via c.Path); without the reset only the first iteration would
// exercise the rewrite and the rest would measure a non-matching path.
func Benchmark_Rewrite(b *testing.B) {
	cases := []struct {
		name        string
		cfg         Config
		uri         string
		withDefault bool
	}{
		{
			name: "Next always true",
			cfg: Config{
				Next:  func(fiber.Ctx) bool { return true },
				Rules: map[string]string{"/old": "/new"},
			},
			uri: "/old",
		},
		{
			name: "Next always false",
			cfg: Config{
				Next:  func(fiber.Ctx) bool { return false },
				Rules: map[string]string{"/old": "/new"},
			},
			uri: "/old",
		},
		{
			name: "Rewrite with tokens",
			cfg:  Config{Rules: map[string]string{"/users/*/orders/*": "/user/$1/order/$2"}},
			uri:  "/users/123/orders/456",
		},
		{
			name:        "NonMatch with default",
			cfg:         Config{Rules: map[string]string{"/users/*/orders/*": "/user/$1/order/$2"}},
			uri:         "/not-matching-any-rule",
			withDefault: true,
		},
		{
			name: "NonMatch without default",
			cfg:  Config{Rules: map[string]string{"/users/*/orders/*": "/user/$1/order/$2"}},
			uri:  "/not-matching-any-rule",
		},
	}
	for _, bc := range cases {
		b.Run(bc.name, func(b *testing.B) {
			app := fiber.New()
			app.Use(New(bc.cfg))
			if bc.withDefault {
				app.Use(func(c fiber.Ctx) error {
					return c.SendStatus(fiber.StatusOK)
				})
			}
			handler := app.Handler()
			reqCtx := &fasthttp.RequestCtx{}
			b.ReportAllocs()
			for b.Loop() {
				reqCtx.Request.SetRequestURI(bc.uri)
				handler(reqCtx)
			}
		})
	}
}
// Benchmark_Rewrite_Parallel is the concurrent counterpart of
// Benchmark_Rewrite; each worker goroutine uses its own RequestCtx.
//
// app.Handler() is resolved once, before the timer is reset, so it is neither
// measured nor invoked concurrently by the workers. Each iteration resets the
// request URI because a successful rewrite updates the path held by reqCtx.
func Benchmark_Rewrite_Parallel(b *testing.B) {
	cases := []struct {
		name        string
		cfg         Config
		uri         string
		withDefault bool
	}{
		{
			name: "Next always true",
			cfg: Config{
				Next:  func(fiber.Ctx) bool { return true },
				Rules: map[string]string{"/old": "/new"},
			},
			uri: "/old",
		},
		{
			name: "Next always false",
			cfg: Config{
				Next:  func(fiber.Ctx) bool { return false },
				Rules: map[string]string{"/old": "/new"},
			},
			uri: "/old",
		},
		{
			name: "Rewrite with tokens",
			cfg:  Config{Rules: map[string]string{"/users/*/orders/*": "/user/$1/order/$2"}},
			uri:  "/users/123/orders/456",
		},
		{
			name:        "NonMatch with default",
			cfg:         Config{Rules: map[string]string{"/users/*/orders/*": "/user/$1/order/$2"}},
			uri:         "/not-matching-any-rule",
			withDefault: true,
		},
		{
			name: "NonMatch without default",
			cfg:  Config{Rules: map[string]string{"/users/*/orders/*": "/user/$1/order/$2"}},
			uri:  "/not-matching-any-rule",
		},
	}
	for _, bc := range cases {
		b.Run(bc.name, func(b *testing.B) {
			app := fiber.New()
			app.Use(New(bc.cfg))
			if bc.withDefault {
				app.Use(func(c fiber.Ctx) error {
					return c.SendStatus(fiber.StatusOK)
				})
			}
			handler := app.Handler()
			b.ReportAllocs()
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				reqCtx := &fasthttp.RequestCtx{}
				for pb.Next() {
					reqCtx.Request.SetRequestURI(bc.uri)
					handler(reqCtx)
				}
			})
		})
	}
}
================================================
FILE: middleware/session/config.go
================================================
package session
import (
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/fiber/v3/log"
"github.com/gofiber/utils/v2"
)
// Config defines the configuration for the session middleware.
//
// Unset fields are filled in by configDefault: IdleTimeout, Extractor,
// KeyGenerator and CookieSameSite fall back to the values in ConfigDefault.
type Config struct {
// Storage interface for storing session data.
//
// Optional. Default: memory.New()
Storage fiber.Storage
// Store defines the session store.
//
// Optional. If nil, New and NewWithStore create one from this Config via
// NewStore.
Store *Store
// Next defines a function to skip this middleware when it returns true.
//
// Optional. Default: nil
Next func(c fiber.Ctx) bool
// ErrorHandler defines a function to handle errors.
//
// When nil, the middleware uses DefaultErrorHandler instead.
//
// Optional. Default: nil
ErrorHandler func(fiber.Ctx, error)
// KeyGenerator generates the session key.
//
// A nil value is replaced by the default.
//
// Optional. Default: utils.SecureToken
KeyGenerator func() string
// CookieDomain defines the domain of the session cookie.
//
// Optional. Default: ""
CookieDomain string
// CookiePath defines the path of the session cookie.
//
// Optional. Default: ""
CookiePath string
// CookieSameSite specifies the SameSite attribute of the cookie.
//
// An empty string is replaced by the default.
//
// Optional. Default: "Lax"
CookieSameSite string
// Extractor is used to extract the session ID from the request.
// See: https://docs.gofiber.io/guide/extractors
//
// A zero-value Extractor (nil Extract func) is replaced by the default.
//
// Optional. Default: extractors.FromCookie("session_id")
Extractor extractors.Extractor
// IdleTimeout defines the maximum duration of inactivity before the session expires.
//
// Note: The idle timeout is updated on each `Save()` call. If a middleware handler is used, `Save()` is called automatically.
// Values <= 0 are replaced by the default.
//
// Optional. Default: 30 * time.Minute
IdleTimeout time.Duration
// AbsoluteTimeout defines the maximum duration of the session before it expires.
//
// If set to 0, the session will not have an absolute timeout, and will expire after the idle timeout.
// When positive, it must be greater than or equal to the (defaulted) IdleTimeout;
// otherwise configDefault panics.
//
// Optional. Default: 0
AbsoluteTimeout time.Duration
// CookieSecure specifies if the session cookie should be secure.
//
// Optional. Default: false
CookieSecure bool
// CookieHTTPOnly specifies if the session cookie should be HTTP-only.
//
// Optional. Default: false
CookieHTTPOnly bool
// CookieSessionOnly determines if the cookie should expire when the browser session ends.
//
// If true, the cookie will be deleted when the browser is closed.
// Note: This will not delete the session data from the store.
//
// Optional. Default: false
CookieSessionOnly bool
}
// ConfigDefault is the configuration used when New/NewWithStore receive no
// Config. configDefault also reads its fields as per-field fallbacks.
// AbsoluteTimeout is left at zero, meaning no absolute timeout.
var ConfigDefault = Config{
	CookieSameSite: "Lax",
	Extractor:      extractors.FromCookie("session_id"),
	KeyGenerator:   utils.SecureToken,
	IdleTimeout:    30 * time.Minute,
}
// DefaultErrorHandler is the error handler the session middleware falls back
// to when Config.ErrorHandler is nil.
//
// It logs err, then responds with status 500 and the body
// "Internal Server Error". If writing that response fails, the write error is
// logged as well.
//
// Parameters:
//   - c: The Fiber context of the failing request.
//   - err: The error to report.
//
// Usage:
//
//	DefaultErrorHandler(c, err)
func DefaultErrorHandler(c fiber.Ctx, err error) {
	log.Errorf("session error: %v", err)

	resp := c.Status(fiber.StatusInternalServerError)
	sendErr := resp.SendString("Internal Server Error")
	if sendErr == nil {
		return
	}
	log.Errorf("failed to send error response: %v", sendErr)
}
// configDefault resolves the effective middleware configuration.
//
// With no argument it returns ConfigDefault as-is. Otherwise it starts from
// the first Config (further arguments are ignored) and fills in the following
// fields from ConfigDefault when they are unset:
//   - IdleTimeout when <= 0
//   - Extractor when its Extract func is nil
//   - KeyGenerator when nil
//   - CookieSameSite when empty
//
// It panics when AbsoluteTimeout is positive but shorter than the resolved
// IdleTimeout.
//
// Usage:
//
//	cfg := configDefault()
//	cfg := configDefault(customConfig)
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}
	cfg := config[0]

	// Resolve IdleTimeout first: the AbsoluteTimeout check compares against it.
	if cfg.IdleTimeout <= 0 {
		cfg.IdleTimeout = ConfigDefault.IdleTimeout
	}
	if cfg.AbsoluteTimeout > 0 && cfg.IdleTimeout > cfg.AbsoluteTimeout {
		panic("[session] AbsoluteTimeout must be greater than or equal to IdleTimeout")
	}

	// A zero-value Extractor is recognized by its missing Extract func.
	if cfg.Extractor.Extract == nil {
		cfg.Extractor = ConfigDefault.Extractor
	}
	if cfg.KeyGenerator == nil {
		cfg.KeyGenerator = ConfigDefault.KeyGenerator
	}
	if len(cfg.CookieSameSite) == 0 {
		cfg.CookieSameSite = ConfigDefault.CookieSameSite
	}
	return cfg
}
================================================
FILE: middleware/session/config_test.go
================================================
package session
import (
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// TestConfigDefault verifies that configDefault with no arguments yields the
// documented defaults.
func TestConfigDefault(t *testing.T) {
	cfg := configDefault()
	require.Equal(t, 30*time.Minute, cfg.IdleTimeout)
	require.Zero(t, cfg.AbsoluteTimeout)
	require.NotNil(t, cfg.KeyGenerator)
	// Extractor is a struct value, which is never nil; assert on its function
	// field so the check can actually fail.
	require.NotNil(t, cfg.Extractor.Extract)
	require.Equal(t, "session_id", cfg.Extractor.Key)
	require.Equal(t, "Lax", cfg.CookieSameSite)
}
// TestConfigDefaultWithCustomConfig verifies that configDefault keeps the
// caller's explicit settings and fills only the fields left unset.
func TestConfigDefaultWithCustomConfig(t *testing.T) {
	customConfig := Config{
		IdleTimeout:  48 * time.Hour,
		Extractor:    extractors.FromHeader("X-Custom-Session"),
		KeyGenerator: func() string { return "custom_key" },
	}
	cfg := configDefault(customConfig)

	require.Equal(t, 48*time.Hour, cfg.IdleTimeout)
	// The custom generator must be preserved, not swapped for the default.
	require.NotNil(t, cfg.KeyGenerator)
	require.Equal(t, "custom_key", cfg.KeyGenerator())
	// Extractor is a struct value; check its function field instead of the
	// (never-nil) struct itself.
	require.NotNil(t, cfg.Extractor.Extract)
	require.Equal(t, "X-Custom-Session", cfg.Extractor.Key)
	// CookieSameSite was not set and must be filled from ConfigDefault.
	require.Equal(t, "Lax", cfg.CookieSameSite)
}
// TestDefaultErrorHandler verifies that DefaultErrorHandler answers with a 500
// status and the generic error body.
func TestDefaultErrorHandler(t *testing.T) {
	app := fiber.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	// Hand the context back to the app's pool once the test finishes.
	defer app.ReleaseCtx(ctx)

	DefaultErrorHandler(ctx, fiber.ErrInternalServerError)

	require.Equal(t, fiber.StatusInternalServerError, ctx.Response().StatusCode())
	require.Equal(t, "Internal Server Error", string(ctx.Response().Body()))
}
func TestAbsoluteTimeoutValidation(t *testing.T) {
require.PanicsWithValue(t, "[session] AbsoluteTimeout must be greater than or equal to IdleTimeout", func() {
configDefault(Config{
IdleTimeout: 30 * time.Minute,
AbsoluteTimeout: 15 * time.Minute, // Less than IdleTimeout
})
})
}
================================================
FILE: middleware/session/data.go
================================================
package session
import (
"sync"
)
// msgp -file="data.go" -o="data_msgp.go" -tests=true -unexported
// data is the key/value store backing a single session. Every method defined
// on *data in this file acquires the embedded RWMutex (shared for reads,
// exclusive for writes); code touching the Data field directly must take the
// lock itself.
// Session state should remain small to fit common storage payload limits.
//
//go:generate msgp -o=data_msgp.go -tests=true -unexported
//msgp:ignore data
type data struct {
Data map[any]any // Session key counts are expected to be bounded.
sync.RWMutex `msg:"-"` // the msg:"-" tag keeps the lock out of msgp serialization
}
// dataPool recycles *data values. Fresh instances come with an empty, non-nil
// map so they can be written to immediately.
var dataPool = sync.Pool{
	New: func() any {
		return &data{Data: make(map[any]any)}
	},
}
// acquireData takes a *data from dataPool.
//
// The returned value is not cleared here; callers that need an empty map
// should call Reset. It panics if the pool yields anything other than *data.
//
// Usage:
//
//	d := acquireData()
func acquireData() *data {
	d, ok := dataPool.Get().(*data)
	if !ok {
		// Handle unexpected type in the pool
		panic("unexpected type in data pool")
	}
	return d
}
// Reset removes every entry while keeping the existing map allocation, so a
// pooled instance can be reused without reallocating its map.
//
// Usage:
//
//	d.Reset()
func (d *data) Reset() {
	d.Lock()
	clear(d.Data) // cannot panic, so an explicit Unlock is safe
	d.Unlock()
}
// Get returns the value stored under key, or nil when the key is absent.
//
// Parameters:
//   - key: The key to look up.
//
// Returns:
//   - any: The stored value, or nil.
//
// Usage:
//
//	value := d.Get("key")
func (d *data) Get(key any) any {
	d.RLock()
	defer d.RUnlock()
	if value, found := d.Data[key]; found {
		return value
	}
	return nil
}
// Set stores value under key, replacing any previous value.
//
// Parameters:
//   - key: The key to write.
//   - value: The value to associate with key.
//
// Usage:
//
//	d.Set("key", "value")
func (d *data) Set(key, value any) {
	d.Lock()
	defer d.Unlock()
	entries := d.Data
	entries[key] = value
}
// Delete removes key and its value; deleting a missing key is a no-op.
//
// Parameters:
//   - key: The key to remove.
//
// Usage:
//
//	d.Delete("key")
func (d *data) Delete(key any) {
	d.Lock()
	defer d.Unlock()
	entries := d.Data
	delete(entries, key)
}
// Keys returns a snapshot of all keys currently stored, in map iteration
// (unspecified) order. The result is non-nil even when empty.
//
// Returns:
//   - []any: The stored keys.
//
// Usage:
//
//	keys := d.Keys()
func (d *data) Keys() []any {
	d.RLock()
	defer d.RUnlock()
	out := make([]any, len(d.Data))
	idx := 0
	for key := range d.Data {
		out[idx] = key
		idx++
	}
	return out
}
// Len reports how many entries are currently stored.
//
// Returns:
//   - int: The number of key-value pairs.
//
// Usage:
//
//	length := d.Len()
func (d *data) Len() int {
	d.RLock()
	n := len(d.Data)
	d.RUnlock()
	return n
}
================================================
FILE: middleware/session/data_msgp.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package session
================================================
FILE: middleware/session/data_msgp_test.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package session
================================================
FILE: middleware/session/data_test.go
================================================
package session
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestKeys(t *testing.T) {
t.Parallel()
// Test case: Empty data
t.Run("Empty data", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
keys := d.Keys()
require.Empty(t, keys, "Expected no keys in empty data")
})
// Test case: Single key
t.Run("Single key", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
keys := d.Keys()
require.Len(t, keys, 1, "Expected one key")
require.Contains(t, keys, "key1", "Expected key1 to be present")
})
// Test case: Multiple keys
t.Run("Multiple keys", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Set("key2", "value2")
d.Set("key3", "value3")
keys := d.Keys()
require.Len(t, keys, 3, "Expected three keys")
require.Contains(t, keys, "key1", "Expected key1 to be present")
require.Contains(t, keys, "key2", "Expected key2 to be present")
require.Contains(t, keys, "key3", "Expected key3 to be present")
})
// Test case: Concurrent access
t.Run("Concurrent access", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Set("key2", "value2")
d.Set("key3", "value3")
done := make(chan bool)
go func() {
keys := d.Keys()
assert.Len(t, keys, 3, "Expected three keys")
done <- true
}()
go func() {
keys := d.Keys()
assert.Len(t, keys, 3, "Expected three keys")
done <- true
}()
<-done
<-done
})
}
func TestData_Len(t *testing.T) {
t.Parallel()
// Test case: Empty data
t.Run("Empty data", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
length := d.Len()
require.Equal(t, 0, length, "Expected length to be 0 for empty data")
})
// Test case: Single key
t.Run("Single key", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
length := d.Len()
require.Equal(t, 1, length, "Expected length to be 1 when one key is set")
})
// Test case: Multiple keys
t.Run("Multiple keys", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Set("key2", "value2")
d.Set("key3", "value3")
length := d.Len()
require.Equal(t, 3, length, "Expected length to be 3 when three keys are set")
})
// Test case: Concurrent access
t.Run("Concurrent access", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Set("key2", "value2")
d.Set("key3", "value3")
done := make(chan bool, 2) // Buffered channel with size 2
go func() {
length := d.Len()
assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access")
done <- true
}()
go func() {
length := d.Len()
assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access")
done <- true
}()
<-done
<-done
})
}
func TestData_Get(t *testing.T) {
t.Parallel()
// Test case: Nonexistent key
t.Run("Nonexistent key", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
value := d.Get("nonexistent-key")
require.Nil(t, value, "Expected nil for nonexistent key")
})
// Test case: Existing key
t.Run("Existing key", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
value := d.Get("key1")
require.Equal(t, "value1", value, "Expected value1 for key1")
})
}
// TestData_Reset verifies that Reset removes all stored entries.
func TestData_Reset(t *testing.T) {
	t.Parallel()

	t.Run("Reset data", func(t *testing.T) {
		t.Parallel()
		d := acquireData()
		d.Reset() // Ensure clean state from pool
		t.Cleanup(func() { dataPool.Put(d) })

		for _, kv := range [][2]string{{"key1", "value1"}, {"key2", "value2"}} {
			d.Set(kv[0], kv[1])
		}
		d.Reset()
		requireDataEmpty(t, d, "Expected data map to be empty after reset")
	})
}
// mapPointer returns the address of m's underlying map header (0 for a nil
// map), letting tests tell whether two map values share storage.
func mapPointer(m map[any]any) uintptr {
	v := reflect.ValueOf(m)
	return v.Pointer()
}
// lockedMapPointer reads d.Data under d's read lock and returns its map
// address via mapPointer.
func lockedMapPointer(d *data) uintptr {
	d.RLock()
	defer d.RUnlock()
	current := d.Data
	return mapPointer(current)
}
// requireDataEmpty fails the test with msg unless d holds no entries. The read
// lock is held during the check so it does not race with writers.
func requireDataEmpty(t *testing.T, d *data, msg string) {
	t.Helper()
	d.RLock()
	defer d.RUnlock() // also released when require calls FailNow
	entries := d.Data
	require.Empty(t, entries, msg)
}
// TestData_ResetPreservesAllocation verifies that writes and Reset keep using
// the same underlying map, and that cleared keys do not reappear.
func TestData_ResetPreservesAllocation(t *testing.T) {
	t.Parallel()
	d := acquireData()
	d.Reset() // Ensure clean state from pool
	t.Cleanup(func() {
		d.Reset()
		dataPool.Put(d)
	})

	basePtr := lockedMapPointer(d)
	requireSameMap := func(msg string) {
		t.Helper()
		require.Equal(t, basePtr, lockedMapPointer(d), msg)
	}

	d.Set("key1", "value1")
	d.Set("key2", "value2")
	requireSameMap("Expected map pointer to stay constant after writes")

	d.Reset()
	requireDataEmpty(t, d, "Expected data map to be empty after reset")
	requireSameMap("Expected reset to preserve underlying map")

	d.Set("key3", "value3")
	require.Nil(t, d.Get("key1"), "Expected cleared key not to leak after reset")
	requireSameMap("Expected map pointer to remain stable after further writes")
}
// TestData_PoolReuseDoesNotLeakEntries verifies that a reset *data returned to
// dataPool comes back empty, whether or not the pool hands back the same
// instance.
//
// The instance explicitly returned to the pool ("first") is deliberately not
// tracked for cleanup. Once Put, it belongs to the pool and a parallel test may
// acquire it at any moment. Resetting or re-putting it from this test's cleanup
// would race with that test, and putting it twice would let two callers share
// one instance. If the pool hands it back to us, it is tracked as a candidate
// instead.
func TestData_PoolReuseDoesNotLeakEntries(t *testing.T) {
	t.Parallel()
	const maxAttempts = 5
	acquired := make([]*data, 0, maxAttempts)
	t.Cleanup(func() {
		for _, item := range acquired {
			item.Reset()
			dataPool.Put(item)
		}
	})
	acquireWithCleanup := func() *data {
		d := acquireData()
		acquired = append(acquired, d)
		return d
	}

	// Not tracked: ownership goes back to the pool via the explicit Put below.
	first := acquireData()
	first.Set("key1", "value1")
	first.Set("key2", "value2")
	first.Reset()
	originalPtr := lockedMapPointer(first)
	dataPool.Put(first)

	var reused *data
	for range maxAttempts {
		candidate := acquireWithCleanup()
		if lockedMapPointer(candidate) == originalPtr {
			reused = candidate
			break
		}
		requireDataEmpty(t, candidate, "Expected pooled data to be empty when new instance is returned")
		require.Nil(t, candidate.Get("key2"), "Expected no leakage of prior entries on alternate pooled instance")
	}
	if reused == nil {
		// t.Skip stops the test goroutine; nothing after it runs.
		t.Skip("sync.Pool returned a different instance; reuse cannot be asserted")
	}

	require.Equal(t, originalPtr, lockedMapPointer(reused), "Expected pooled data to reuse cleared map")
	requireDataEmpty(t, reused, "Expected pooled data to be empty after reuse")
	require.Nil(t, reused.Get("key2"), "Expected no leakage of prior entries on reuse")
	reused.Set("key4", "value4")
	require.Equal(t, "value4", reused.Get("key4"), "Expected pooled map to accept new values")
}
func TestData_Delete(t *testing.T) {
t.Parallel()
// Test case: Delete existing key
t.Run("Delete existing key", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
d.Set("key1", "value1")
d.Delete("key1")
value := d.Get("key1")
require.Nil(t, value, "Expected nil for deleted key")
})
// Test case: Delete nonexistent key
t.Run("Delete nonexistent key", func(t *testing.T) {
t.Parallel()
d := acquireData()
d.Reset() // Ensure clean state from pool
defer dataPool.Put(d)
defer d.Reset()
d.Delete("nonexistent-key")
// No assertion needed, just ensure no panic or error
})
}
================================================
FILE: middleware/session/middleware.go
================================================
// Package session provides session management middleware for Fiber.
// This middleware handles user sessions, including storing session data in the store.
package session
import (
"errors"
"sync"
"github.com/gofiber/fiber/v3"
)
// Middleware holds session data and configuration.
//
// The handler returned by NewWithStore takes one Middleware from
// middlewarePool per request and populates it in initialize. It stores it in
// the request context (retrieve it with FromContext) and returns it to the
// pool via releaseMiddleware after the downstream handlers have run. A
// *Middleware therefore must not be retained beyond its request.
type Middleware struct {
Session *Session // session loaded from config.Store for the current request
ctx fiber.Ctx // current request context; passed to the error handler when saving fails
config Config // copy of the middleware configuration
mu sync.RWMutex // taken by initialize, releaseMiddleware and most accessor methods
destroyed bool // set by Destroy; the handler then skips saving the session
}
// middlewareKey is an unexported key type. Values stored under it cannot
// collide with context keys defined by other packages.
type middlewareKey int

// middlewareContextKey is the key used to store the *Middleware in the context locals.
const middlewareContextKey middlewareKey = 0
var (
	// ErrTypeAssertionFailed occurs when a type assertion fails; acquireMiddleware
	// panics with its message if the pool yields something other than *Middleware.
	ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware")

	// middlewarePool recycles Middleware values between requests.
	middlewarePool = &sync.Pool{
		New: func() any { return new(Middleware) },
	}
)
// New returns the session middleware handler. It is NewWithStore without the
// *Store result; use NewWithStore when the store is needed.
//
// Parameters:
//   - config: Optional. Only the first Config is used; unset fields take
//     their defaults.
//
// Returns:
//   - fiber.Handler: The Fiber handler for the session middleware.
//
// Usage:
//
//	app.Use(session.New())
func New(config ...Config) fiber.Handler {
	handler, _ := NewWithStore(config...)
	return handler
}
// NewWithStore creates the session middleware and returns it together with
// the store it uses.
//
// If no Store is configured, one is built from the resolved configuration
// with NewStore.
//
// For each request the handler does the following, unless Next returns true:
//  1. Loads the session and exposes it through FromContext.
//  2. Runs the rest of the chain.
//  3. Saves the session, unless Destroy was called.
//
// If the session cannot be loaded, the error goes to Config.ErrorHandler (or
// DefaultErrorHandler) and the rest of the chain is not run. Save errors are
// reported the same way. The error returned by downstream handlers is passed
// through.
//
// Parameters:
//   - config: Variadic parameter to override default config.
//
// Returns:
//   - fiber.Handler: The Fiber handler for the session middleware.
//   - *Store: The session store.
//
// Usage:
//
//	handler, store := session.NewWithStore()
func NewWithStore(config ...Config) (fiber.Handler, *Store) {
	cfg := configDefault(config...)
	if cfg.Store == nil {
		cfg.Store = NewStore(cfg)
	}

	handler := func(c fiber.Ctx) error {
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Acquire session middleware
		m := acquireMiddleware()
		if err := m.initialize(c, &cfg); err != nil {
			// Loading failed (e.g. the storage backend is unavailable). Report it
			// like save failures are reported in saveSession instead of
			// panicking. An unrecovered panic in a request handler terminates
			// the whole process.
			releaseMiddleware(m)
			if cfg.ErrorHandler != nil {
				cfg.ErrorHandler(c, err)
			} else {
				DefaultErrorHandler(c, err)
			}
			return nil
		}

		stackErr := c.Next()

		m.mu.RLock()
		destroyed := m.destroyed
		m.mu.RUnlock()

		if !destroyed {
			m.saveSession()
		}
		// NOTE(review): when the session was destroyed, saveSession (and with it
		// releaseSession) is skipped; confirm Destroy leaves nothing that must
		// be returned to the session pool.

		releaseMiddleware(m)
		return stackErr
	}

	return handler, cfg.Store
}

// initialize loads the request's session from cfg.Store and binds it, the
// context, and a copy of cfg to m. It then stores m in the context locals so
// FromContext can find it.
//
// When the session cannot be loaded, it returns the store's error unchanged
// and leaves m unmodified (and not registered in the context).
func (m *Middleware) initialize(c fiber.Ctx, cfg *Config) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	session, err := cfg.Store.getSession(c)
	if err != nil {
		return err
	}

	m.config = *cfg
	m.Session = session
	m.ctx = c

	fiber.StoreInContext(c, middlewareContextKey, m)
	return nil
}
// saveSession persists the session once the downstream handlers have run. A
// save error is passed to the configured ErrorHandler, or to
// DefaultErrorHandler when none is set. The session is then released whether
// or not saving succeeded.
func (m *Middleware) saveSession() {
	if err := m.Session.saveSession(); err != nil {
		report := m.config.ErrorHandler
		if report == nil {
			report = DefaultErrorHandler
		}
		report(m.ctx, err)
	}
	releaseSession(m.Session)
}
// acquireMiddleware takes a *Middleware from middlewarePool, panicking with
// ErrTypeAssertionFailed's message if the pool holds anything else.
func acquireMiddleware() *Middleware {
	if m, ok := middlewarePool.Get().(*Middleware); ok {
		return m
	}
	panic(ErrTypeAssertionFailed.Error())
}
// releaseMiddleware clears every per-request field of m under its lock and
// puts it back into middlewarePool. m must not be used afterwards.
//
// Parameters:
//   - m: The middleware object to release.
//
// Usage:
//
//	releaseMiddleware(m)
func releaseMiddleware(m *Middleware) {
	m.mu.Lock()
	m.Session, m.ctx = nil, nil
	m.config, m.destroyed = Config{}, false
	m.mu.Unlock()
	middlewarePool.Put(m)
}
// FromContext looks up the session Middleware stored for the current request.
// It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context.
//
// Parameters:
//   - ctx: The context to search.
//
// Returns:
//   - *Middleware: The stored middleware, or nil when none is present.
//
// Usage:
//
//	m := session.FromContext(c)
func FromContext(ctx any) *Middleware {
	m, ok := fiber.ValueFromContext[*Middleware](ctx, middlewareContextKey)
	if !ok {
		return nil
	}
	return m
}
// Set stores value under key in the session while holding m's write lock.
//
// Parameters:
//   - key: The key to set.
//   - value: The value to set.
//
// Usage:
//
//	m.Set("key", "value")
func (m *Middleware) Set(key, value any) {
	m.mu.Lock()
	defer m.mu.Unlock()
	sess := m.Session
	sess.Set(key, value)
}
// Get returns the session value for key while holding m's read lock.
//
// Parameters:
//   - key: The key to look up.
//
// Returns:
//   - any: Whatever Session.Get returns for key.
//
// Usage:
//
//	value := m.Get("key")
func (m *Middleware) Get(key any) any {
	m.mu.RLock()
	defer m.mu.RUnlock()
	sess := m.Session
	return sess.Get(key)
}
// Delete removes key from the session while holding m's write lock.
//
// Parameters:
//   - key: The key to delete.
//
// Usage:
//
//	m.Delete("key")
func (m *Middleware) Delete(key any) {
	m.mu.Lock()
	defer m.mu.Unlock()
	sess := m.Session
	sess.Delete(key)
}
// Keys lists the keys of the current session while holding m's read lock.
//
// Returns:
//   - []any: The keys reported by Session.Keys.
//
// Usage:
//
//	keys := m.Keys()
func (m *Middleware) Keys() []any {
	m.mu.RLock()
	defer m.mu.RUnlock()
	sess := m.Session
	return sess.Keys()
}
// Destroy destroys the session and marks m as destroyed, so the middleware
// handler will not save the session afterwards. The mark is set even when
// Session.Destroy returns an error.
//
// Returns:
//   - error: The error from Session.Destroy, if any.
//
// Usage:
//
//	err := m.Destroy()
func (m *Middleware) Destroy() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if err := m.Session.Destroy(); err != nil {
		m.destroyed = true
		return err
	}
	m.destroyed = true
	return nil
}
// Fresh reports whether the underlying session is fresh, as defined by
// Session.Fresh.
//
// Returns:
//   - bool: The result of Session.Fresh.
//
// Usage:
//
//	isFresh := m.Fresh()
func (m *Middleware) Fresh() bool {
	sess := m.Session
	return sess.Fresh()
}
// ID reports the identifier of the underlying session.
//
// Returns:
//   - string: The value of Session.ID.
//
// Usage:
//
//	id := m.ID()
func (m *Middleware) ID() string {
	sess := m.Session
	return sess.ID()
}
// Reset resets the underlying session while holding m's write lock.
//
// Returns:
//   - error: The error from Session.Reset, if any.
//
// Usage:
//
//	err := m.Reset()
func (m *Middleware) Reset() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	sess := m.Session
	return sess.Reset()
}
// Regenerate gives the session a new ID while keeping its data, holding m's
// write lock for the duration.
//
// This is typically called right after authentication to prevent session
// fixation. Unlike Reset(), it preserves all existing session data.
//
// Returns:
//   - error: The error from Session.Regenerate, if any.
//
// Usage:
//
//	err := m.Regenerate()
func (m *Middleware) Regenerate() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	sess := m.Session
	return sess.Regenerate()
}
// Store returns the session store from the middleware's configuration.
//
// Returns:
//   - *Store: The configured session store.
//
// Usage:
//
//	store := m.Store()
func (m *Middleware) Store() *Store {
	cfg := &m.config
	return cfg.Store
}
================================================
FILE: middleware/session/middleware_test.go
================================================
package session
import (
"fmt"
"net/http"
"net/http/httptest"
"sort"
"strings"
"sync"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func Test_Session_Middleware(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New())
app.Get("/get", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
value, ok := sess.Get("key").(string)
if !ok {
return c.Status(fiber.StatusNotFound).SendString("key not found")
}
return c.SendString("value=" + value)
})
app.Post("/set", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
// get a value from the body
value := c.FormValue("value")
sess.Set("key", value)
return c.SendStatus(fiber.StatusOK)
})
app.Post("/delete", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
sess.Delete("key")
return c.SendStatus(fiber.StatusOK)
})
app.Post("/reset", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
// Set a value to ensure it is cleared after reset
sess.Set("key", "value")
if err := sess.Reset(); err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
// Ensure value is cleared
value, ok := sess.Get("key").(string)
if ok || value != "" {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.SendStatus(fiber.StatusOK)
})
app.Post("/regenerate", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
// Set a value to ensure it is preserved after regeneration
sess.Set("key", "value")
// Regenerate the session ID
if err := sess.Regenerate(); err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
// Ensure the session ID has changed but session data is preserved
newID := sess.ID()
if newID == "" {
return c.SendStatus(fiber.StatusInternalServerError)
}
// Check if the session data is still accessible
value, ok := sess.Get("key").(string)
if !ok || value != "value" {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.SendStatus(fiber.StatusOK)
})
app.Post("/destroy", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
if err := sess.Destroy(); err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
return c.SendStatus(fiber.StatusOK)
})
app.Post("/fresh", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
// Reset the session to make it fresh
if err := sess.Reset(); err != nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
if sess.Fresh() {
return c.SendStatus(fiber.StatusOK)
}
return c.SendStatus(fiber.StatusInternalServerError)
})
app.Post("/keys", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
// get a value from the body
value := c.FormValue("keys")
for rawKey := range strings.SplitSeq(value, ",") {
key := utils.TrimSpace(rawKey)
if key == "" {
continue
}
// Set each key in the session
sess.Set(key, "value_"+key)
}
return c.SendStatus(fiber.StatusOK)
})
app.Get("/keys", func(c fiber.Ctx) error {
sess := FromContext(c)
if sess == nil {
return c.SendStatus(fiber.StatusInternalServerError)
}
keys := sess.Keys()
if len(keys) == 0 {
return c.SendStatus(fiber.StatusNotFound)
}
// Keys may be of any type, so convert to string for display
strKeys := []string{}
for _, key := range keys {
strKeys = append(strKeys, fmt.Sprintf("%v", key))
}
return c.SendString("keys=" + strings.Join(strKeys, ","))
})
// Test GET, SET, DELETE, RESET, REGENERATE, DESTROY by sending requests to the respective routes
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/get")
h := app.Handler()
h(ctx)
require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
token = tokenParts[1]
require.Equal(t, "key not found", string(ctx.Response.Body()))
// Test POST /set
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/set")
ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Set the Content-Type
ctx.Request.SetBodyString("value=hello")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Test GET /get to check if the value was set
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/get")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
require.Equal(t, "value=hello", string(ctx.Response.Body()))
// Test POST /delete to delete the value
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/delete")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Test GET /get to check if the value was deleted
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/get")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
require.Equal(t, "key not found", string(ctx.Response.Body()))
// Test POST /reset to reset the session
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/reset")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// verify we have a new session token
newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
newTokenParts := strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2)
require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token")
newToken = newTokenParts[1]
require.NotEqual(t, token, newToken)
token = newToken
// Test POST /regenerate to regenerate the session ID
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/regenerate")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// verify we have a new session token
newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
newTokenParts = strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2)
require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token")
newToken = newTokenParts[1]
require.NotEqual(t, token, newToken)
token = newToken
// Test POST /destroy to destroy the session
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/destroy")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Verify the session cookie has expired
setCookieHeader := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.Contains(t, setCookieHeader, "max-age=0")
// Sleep so that the session expires
time.Sleep(1 * time.Second)
// Test GET /get to check if the session was destroyed
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/get")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
// check that we have a new session token
newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
parts := strings.Split(newToken, ";")
require.Greater(t, len(parts), 1)
valueParts := strings.Split(parts[0], "=")
require.Greater(t, len(valueParts), 1)
newToken = valueParts[1]
require.NotEqual(t, token, newToken)
token = newToken
// Test POST /fresh to check if the session is fresh
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/fresh")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// check that we have a new session token
newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
newTokenParts = strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2)
require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token")
newToken = newTokenParts[1]
require.NotEqual(t, token, newToken)
token = newToken
// Test POST /keys to set multiple keys
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.SetRequestURI("/keys")
ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Set the Content-Type
ctx.Request.SetBodyString("keys=key1,key2")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
// Test GET /keys to check if the session has the keys
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.SetRequestURI("/keys")
ctx.Request.Header.SetCookie("session_id", token)
h(ctx)
require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
body := string(ctx.Response.Body())
require.True(t, strings.HasPrefix(body, "keys="))
parts = strings.Split(strings.TrimPrefix(body, "keys="), ",")
require.Len(t, parts, 2, "Expected two keys in the session")
sort.Strings(parts)
require.Equal(t, []string{"key1", "key2"}, parts)
}
// Test_Session_NewWithStore checks that the default middleware issues a
// session cookie on the first request and resolves the same session when that
// cookie is replayed.
func Test_Session_NewWithStore(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	// GET echoes the ID of the session attached by the middleware.
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("value=" + FromContext(c).ID())
	})
	// POST writes the session ID back as a "session_id" cookie.
	app.Post("/", func(c fiber.Ctx) error {
		c.Cookie(&fiber.Cookie{
			Name:  "session_id",
			Value: FromContext(c).ID(),
		})
		return nil
	})
	handler := app.Handler()

	// First request carries no cookie, so the middleware must mint a session.
	first := &fasthttp.RequestCtx{}
	first.Request.Header.SetMethod(fiber.MethodGet)
	handler(first)
	require.Equal(t, fiber.StatusOK, first.Response.StatusCode())

	// Pull the cookie value out of "name=value; attr; ...".
	setCookie := string(first.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, setCookie, "Expected Set-Cookie header to be present")
	pair := strings.SplitN(setCookie, ";", 2)[0]
	kv := strings.SplitN(pair, "=", 2)
	require.Len(t, kv, 2, "Expected Set-Cookie header to contain a token")
	sessionID := kv[1]
	require.Equal(t, "value="+sessionID, string(first.Response.Body()))

	// Replaying the cookie must resolve to the very same session.
	second := &fasthttp.RequestCtx{}
	second.Request.Header.SetMethod(fiber.MethodGet)
	second.Request.Header.SetCookie("session_id", sessionID)
	handler(second)
	require.Equal(t, fiber.StatusOK, second.Response.StatusCode())
	require.Equal(t, "value="+sessionID, string(second.Response.Body()))
}
// Test_Session_FromSession verifies that FromContext returns nil for a context
// that never passed through the session middleware.
func Test_Session_FromSession(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Return the pooled context when done instead of leaking it from the app's pool.
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)

	sess := FromContext(ctx)
	require.Nil(t, sess)

	// Registering the middleware afterwards must still work.
	app.Use(New())
}
// Test_Session_FromContext_Types checks that, with PassLocalsToContext enabled,
// the session is reachable through fiber.Ctx, fiber.CustomCtx, the underlying
// *fasthttp.RequestCtx and the context.Context returned by c.Context().
func Test_Session_FromContext_Types(t *testing.T) {
	t.Parallel()

	app := fiber.New(fiber.Config{PassLocalsToContext: true})
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		require.NotNil(t, FromContext(c))

		custom, isCustom := c.(fiber.CustomCtx)
		require.True(t, isCustom)
		require.NotNil(t, FromContext(custom))

		require.NotNil(t, FromContext(c.RequestCtx()))
		require.NotNil(t, FromContext(c.Context()))

		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
// Test_Session_WithConfig exercises the middleware with a custom Config.
//
// The config uses:
//   - a Next func that skips the middleware when the "key" header equals "value"
//   - a 1s IdleTimeout
//   - the "session_id_test" cookie as extractor
//   - a KeyGenerator that always returns "test"
//
// The test sends GET/POST requests with a matching cookie, with no cookie and
// with a mismatched cookie. It then checks that a session stops being fresh
// while stored and becomes fresh again once IdleTimeout has elapsed.
func Test_Session_WithConfig(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	// None of the requests below set the "key" header, so Next never skips.
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Get("key") == "value"
		},
		IdleTimeout: 1 * time.Second,
		Extractor:   extractors.FromCookie("session_id_test"),
		KeyGenerator: func() string {
			return "test"
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		id := sess.ID()
		return c.SendString("value=" + id)
	})
	// Reports whether the session resolved for this request is fresh (200) or
	// was loaded from storage (500).
	app.Get("/isFresh", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess.Fresh() {
			return c.SendStatus(fiber.StatusOK)
		}
		return c.SendStatus(fiber.StatusInternalServerError)
	})
	app.Post("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		id := sess.ID()
		c.Cookie(&fiber.Cookie{
			Name:  "session_id_test",
			Value: id,
		})
		return nil
	})
	h := app.Handler()
	// Test GET request without cookie
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// Get session cookie
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
	// Keep only the value of the leading "name=value" pair.
	tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
	require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
	token = tokenParts[1]
	require.Equal(t, "value="+token, string(ctx.Response.Body()))
	// Test GET request with cookie
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("session_id_test", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	require.Equal(t, "value="+token, string(ctx.Response.Body()))
	// Test POST request with cookie
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie("session_id_test", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// Test POST request without cookie
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// Test POST request with wrong key: the "session_id" cookie is not the
	// configured extractor, so it is ignored; the request still succeeds.
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// Test POST request with wrong value: an unknown ID still yields 200.
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie("session_id_test", "wrong")
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// Check idle timeout not expired: the stored session is found, so it is
	// not fresh and /isFresh answers 500.
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("session_id_test", token)
	ctx.Request.SetRequestURI("/isFresh")
	h(ctx)
	require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode())
	// Test idle timeout: sleep past the 1s IdleTimeout. The stored entry is
	// expected to have expired, so the session is fresh again (200).
	time.Sleep(1200 * time.Millisecond)
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("session_id_test", token)
	ctx.Request.SetRequestURI("/isFresh")
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
}
// Test_Session_Next checks that the middleware attaches a session when
// Config.Next returns false and is bypassed entirely when it returns true.
func Test_Session_Next(t *testing.T) {
	t.Parallel()

	// skip is read by the Next callback and flipped by the test, so it is
	// guarded by skipMu.
	var (
		skip   bool
		skipMu sync.RWMutex
	)
	enableSkip := func() {
		skipMu.Lock()
		skip = true
		skipMu.Unlock()
	}

	app := fiber.New()
	app.Use(New(Config{
		Next: func(_ fiber.Ctx) bool {
			skipMu.RLock()
			defer skipMu.RUnlock()
			return skip
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			// Middleware was skipped, so no session is attached.
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendString("value=" + sess.ID())
	})
	handler := app.Handler()

	// Next returns false: the middleware runs and issues a session cookie.
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	handler(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	setCookie := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, setCookie, "Expected Set-Cookie header to be present")
	kv := strings.SplitN(strings.SplitN(setCookie, ";", 2)[0], "=", 2)
	require.Len(t, kv, 2, "Expected Set-Cookie header to contain a token")
	require.Equal(t, "value="+kv[1], string(ctx.Response.Body()))

	// Next returns true: the middleware is bypassed and the handler sees no session.
	enableSkip()
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	handler(ctx)
	require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode())
}
// Test_Session_Middleware_Store checks that sessions created by the middleware
// report the exact store instance returned by NewWithStore.
func Test_Session_Middleware_Store(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	handler, sessionStore := NewWithStore()
	app.Use(handler)
	app.Get("/", func(c fiber.Ctx) error {
		status := fiber.StatusOK
		if FromContext(c).Store() != sessionStore {
			status = fiber.StatusInternalServerError
		}
		return c.SendStatus(status)
	})
	serve := app.Handler()

	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	serve(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
}
================================================
FILE: middleware/session/session.go
================================================
package session
import (
"bytes"
"context"
"encoding/gob"
"fmt"
"sync"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
// Session represents a user session.
//
// Key/value data lives in data, which carries its own read/write lock
// (saveSession takes s.data.RLock). mu guards the remaining fields.
type Session struct {
	ctx         fiber.Ctx     // fiber context of the owning request; may be nil (e.g. after release)
	config      *Store        // store configuration (storage, key generator, cookie settings)
	data        *data         // key value data
	id          string        // session id
	idleTimeout time.Duration // idle timeout applied on save; <= 0 means "use config.IdleTimeout"
	mu          sync.RWMutex  // protects non-data fields
	fresh       bool          // true when the session was newly created rather than loaded
}
// absExpirationKeyType is an unexported key type for session data. Because
// interface keys compare by dynamic type as well as value, entries stored
// under it cannot collide with keys supplied by package users.
type absExpirationKeyType int

const (
	// absExpirationKey is the session-data key under which the absolute
	// expiration time (a time.Time) is stored.
	absExpirationKey absExpirationKeyType = iota
)
// byteBufferPool recycles *bytes.Buffer values used as scratch space when
// gob-encoding and decoding session data. Callers Reset a buffer before
// putting it back, so buffers taken from the pool start out empty.
var byteBufferPool = sync.Pool{
	New: func() any {
		return new(bytes.Buffer)
	},
}
// sessionPool recycles Session values; see acquireSession and releaseSession.
var sessionPool = sync.Pool{
	New: func() any { return new(Session) },
}
// acquireSession takes a Session from sessionPool, attaching a data container
// if it lacks one and marking it fresh.
//
// Returns:
//   - *Session: The session object.
//
// Usage:
//
//	s := acquireSession()
func acquireSession() *Session {
	sess := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
	if sess.data == nil {
		// Pooled sessions keep their (reset) data container between uses.
		sess.data = acquireData()
	}
	sess.fresh = true
	return sess
}
// Release returns the session to the pool.
//
// Call it once the session is no longer needed. Pooling reduces allocations
// and improves the performance of the session store. Calling Release on a nil
// *Session is a no-op.
//
// The session must not be used after calling this function.
//
// Important: only call Release for sessions you obtained directly from a
// Store, e.g. via (*Store).Get(c). Do not call it on a session provided by the
// *Middleware handler in the request call stack, because the middleware still
// needs to access that session after your handler returns.
//
// Usage:
//
//	sess, err := store.Get(c)
//	if err != nil {
//		return err
//	}
//	defer sess.Release()
func (s *Session) Release() {
	if s != nil {
		releaseSession(s)
	}
}
// releaseSession clears the per-request state of s and puts it back into
// sessionPool. The data container is reset rather than dropped so it can be
// reused by the next acquireSession.
func releaseSession(s *Session) {
	s.mu.Lock()
	if d := s.data; d != nil {
		d.Reset()
	}
	s.ctx, s.config = nil, nil
	s.id, s.idleTimeout = "", 0
	s.mu.Unlock()

	sessionPool.Put(s)
}
// Fresh reports whether the session was newly created instead of being loaded
// from storage.
//
// Returns:
//   - bool: True if the session is fresh; otherwise, false.
//
// Usage:
//
//	isFresh := s.Fresh()
func (s *Session) Fresh() bool {
	s.mu.RLock()
	isFresh := s.fresh
	s.mu.RUnlock()
	return isFresh
}
// ID returns the current session identifier.
//
// Returns:
//   - string: The session ID.
//
// Usage:
//
//	id := s.ID()
func (s *Session) ID() string {
	s.mu.RLock()
	sessionID := s.id
	s.mu.RUnlock()
	return sessionID
}
// Get looks up the value stored under key.
//
// Parameters:
//   - key: The key to retrieve.
//
// Returns:
//   - any: The stored value, or nil when the session has no data container.
//
// Usage:
//
//	value := s.Get("key")
func (s *Session) Get(key any) any {
	if d := s.data; d != nil {
		return d.Get(key)
	}
	return nil
}
// Set stores val under key, replacing any previous value. It is a no-op when
// the session has no data container.
//
// Parameters:
//   - key: The key to set.
//   - val: The value to set.
//
// Usage:
//
//	s.Set("key", "value")
func (s *Session) Set(key, val any) {
	if d := s.data; d != nil {
		d.Set(key, val)
	}
}
// Delete removes key and its value from the session. It is a no-op when the
// session has no data container.
//
// Parameters:
//   - key: The key to delete.
//
// Usage:
//
//	s.Delete("key")
func (s *Session) Delete(key any) {
	if d := s.data; d != nil {
		d.Delete(key)
	}
}
// Destroy clears the session data, removes the session from storage and, when
// a request context is attached, expires the client-side session identifier.
//
// Returns:
//   - error: The storage error if deleting the session fails; the client-side
//     identifier is left untouched in that case.
//
// Usage:
//
//	err := s.Destroy()
func (s *Session) Destroy() error {
	if s.data == nil {
		return nil
	}

	// Drop in-memory values first; data is guarded by its own lock, not s.mu.
	s.data.Reset()

	s.mu.RLock()
	defer s.mu.RUnlock()

	// Storage calls use the request context when available.
	var storageCtx context.Context = context.Background()
	if s.ctx != nil {
		storageCtx = s.ctx
	}

	err := s.config.Storage.DeleteWithContext(storageCtx, s.id)
	if err == nil {
		s.delSession()
	}
	return err
}
// Regenerate deletes the current ID from storage and assigns a new one. The
// session data is kept and the session is marked fresh.
//
// Returns:
//   - error: The storage error if deleting the old ID fails; the ID is left
//     unchanged in that case.
//
// Usage:
//
//	err := s.Regenerate()
func (s *Session) Regenerate() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var storageCtx context.Context = context.Background()
	if s.ctx != nil {
		storageCtx = s.ctx
	}

	err := s.config.Storage.DeleteWithContext(storageCtx, s.id)
	if err != nil {
		return err
	}
	s.refresh()
	return nil
}
// Reset clears the session data and idle timeout, deletes the current ID from
// storage, expires the client-side identifier and assigns a new, fresh ID.
//
// Returns:
//   - error: The storage error if deleting the old ID fails.
//
// Usage:
//
//	err := s.Reset()
func (s *Session) Reset() error {
	// Clear values first; data is guarded by its own lock, not s.mu.
	if d := s.data; d != nil {
		d.Reset()
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// A zero idle timeout makes the next save fall back to config.IdleTimeout.
	s.idleTimeout = 0

	var storageCtx context.Context = context.Background()
	if s.ctx != nil {
		storageCtx = s.ctx
	}
	if err := s.config.Storage.DeleteWithContext(storageCtx, s.id); err != nil {
		return err
	}

	s.delSession() // expire the old identifier on the client
	s.refresh()    // new ID, marked fresh
	return nil
}
// refresh assigns an ID from the configured KeyGenerator and marks the session
// fresh. Its callers in this file (Regenerate, Reset) hold s.mu for writing.
func (s *Session) refresh() {
	newID := s.config.KeyGenerator()
	s.id, s.fresh = newID, true
}
// Save persists the session data and updates the client-side identifier.
//
// Note: when this session is the one managed by the middleware for the current
// request, Save does nothing. The middleware saves it automatically when the
// handler returns.
//
// Returns:
//   - error: An error if the save operation fails.
//
// Usage:
//
//	err := s.Save()
func (s *Session) Save() error {
	if s.ctx != nil {
		m, ok := s.ctx.Locals(middlewareContextKey).(*Middleware)
		if ok && m.Session == s {
			// Owned by the middleware: it will save on the way out.
			return nil
		}
	}
	return s.saveSession()
}
// saveSession encodes the session data and writes it to storage under the
// session ID. When a request context is attached, it also writes the session
// ID to the response cookie/header.
//
// Encoding happens before the client-side identifier is written, so an encode
// failure (e.g. an unregistered gob type) does not hand the client an ID for
// data that was never persisted.
//
// Returns:
//   - error: An error if encoding fails, or the storage error if the write fails.
func (s *Session) saveSession() error {
	if s.data == nil {
		return nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Fall back to the store's configured idle timeout if none was set explicitly.
	if s.idleTimeout <= 0 {
		s.idleTimeout = s.config.IdleTimeout
	}

	// Encode session data under the data container's read lock.
	s.data.RLock()
	encodedBytes, err := s.encodeSessionData()
	s.data.RUnlock()
	if err != nil {
		return fmt.Errorf("failed to encode data: %w", err)
	}

	// Update client cookie/header only once there is something to persist.
	s.setSession()

	var storageCtx context.Context = context.Background()
	if s.ctx != nil {
		storageCtx = s.ctx
	}
	// encodedBytes is a private copy, so the storage may retain it.
	return s.config.Storage.SetWithContext(storageCtx, s.id, encodedBytes, s.idleTimeout)
}
// Keys lists every key currently stored in the session.
//
// Returns:
//   - []any: The session keys; an empty (non-nil) slice when the session has no
//     data container.
//
// Usage:
//
//	keys := s.Keys()
func (s *Session) Keys() []any {
	if d := s.data; d != nil {
		return d.Keys()
	}
	return []any{}
}
// SetIdleTimeout sets the idle timeout applied the next time the session is
// saved (see Save). A value <= 0 makes the save use the store's IdleTimeout.
//
// Parameters:
//   - idleTimeout: The duration for the idle timeout.
//
// Usage:
//
//	s.SetIdleTimeout(time.Hour)
func (s *Session) SetIdleTimeout(idleTimeout time.Duration) {
	s.mu.Lock()
	s.idleTimeout = idleTimeout
	s.mu.Unlock()
}
// getExtractorInfo returns the extractors through which the session ID can be
// written back to the client, i.e. the cookie and header extractors.
//
// A chained extractor contributes each of its cookie/header links in order. A
// single extractor contributes itself only if it is a cookie or header
// extractor. Other sources (query, param, form, custom) are read-only and are
// omitted, so the result may be empty. Only a nil config falls back to the
// default "session_id" cookie.
func (s *Session) getExtractorInfo() []extractors.Extractor {
	if s.config == nil {
		return []extractors.Extractor{{Source: extractors.SourceCookie, Key: "session_id"}} // Safe default
	}

	root := s.config.Extractor
	candidates := root.Chain
	if len(candidates) == 0 {
		candidates = []extractors.Extractor{root}
	}

	var writable []extractors.Extractor
	for _, candidate := range candidates {
		switch candidate.Source {
		case extractors.SourceCookie, extractors.SourceHeader:
			writable = append(writable, candidate)
		default:
			// Read-only source; the ID cannot be written back through it.
		}
	}
	return writable
}
// setSession writes the session ID to the response through every writable
// extractor. Header extractors get a response header; cookie extractors get a
// Set-Cookie built from the store's cookie settings. It is a no-op when no
// request context is attached.
func (s *Session) setSession() {
	if s.ctx == nil {
		return
	}

	for _, target := range s.getExtractorInfo() {
		switch target.Source {
		case extractors.SourceHeader:
			s.ctx.Response().Header.SetBytesV(target.Key, utils.UnsafeBytes(s.id))
		case extractors.SourceCookie:
			cookie := fasthttp.AcquireCookie()
			cookie.SetKey(target.Key)
			cookie.SetValue(s.id)
			cookie.SetPath(s.config.CookiePath)
			cookie.SetDomain(s.config.CookieDomain)
			// Without Expires/Max-Age the cookie is a browser-session cookie.
			// refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
			if !s.config.CookieSessionOnly {
				cookie.SetMaxAge(int(s.idleTimeout.Seconds()))
				cookie.SetExpire(time.Now().Add(s.idleTimeout))
			}
			s.setCookieAttributes(cookie)
			s.ctx.Response().Header.SetCookie(cookie)
			fasthttp.ReleaseCookie(cookie)
		default:
			// For non-cookie/header sources, do nothing (read-only)
		}
	}
}
// delSession removes the session ID from the current request and response and
// tells the client to discard it.
//
// For header extractors the header is deleted from both request and response.
// For cookie extractors the cookie is deleted from the request and from the
// response, and then an expired replacement Set-Cookie is added (MaxAge -1,
// Expires one minute in the past). It is a no-op when no request context is
// attached.
func (s *Session) delSession() {
	if s.ctx == nil {
		return
	}
	// Get all relevant extractors
	relevantExtractors := s.getExtractorInfo()
	// Delete session ID for each extractor type
	for _, ext := range relevantExtractors {
		switch ext.Source {
		case extractors.SourceHeader:
			s.ctx.Request().Header.Del(ext.Key)
			s.ctx.Response().Header.Del(ext.Key)
		case extractors.SourceCookie:
			// Remove the cookie from the request and from any Set-Cookie already
			// queued on the response, then queue an already-expired replacement.
			s.ctx.Request().Header.DelCookie(ext.Key)
			s.ctx.Response().Header.DelCookie(ext.Key)
			fcookie := fasthttp.AcquireCookie()
			fcookie.SetKey(ext.Key)
			// Path/Domain mirror setSession so the client matches the same cookie.
			fcookie.SetPath(s.config.CookiePath)
			fcookie.SetDomain(s.config.CookieDomain)
			fcookie.SetMaxAge(-1)
			fcookie.SetExpire(time.Now().Add(-1 * time.Minute))
			s.setCookieAttributes(fcookie)
			s.ctx.Response().Header.SetCookie(fcookie)
			fasthttp.ReleaseCookie(fcookie)
		default:
			// For non-cookie/header sources, do nothing (read-only)
		}
	}
}
// setCookieAttributes applies the SameSite, Secure and HttpOnly settings from
// the store config to fcookie.
//
// CookieSameSite is matched case-insensitively against "Strict" and "None";
// anything else, including an empty value, yields Lax. SameSite=None always
// forces Secure.
func (s *Session) setCookieAttributes(fcookie *fasthttp.Cookie) {
	var sameSite fasthttp.CookieSameSite
	switch mode := s.config.CookieSameSite; {
	case utils.EqualFold(mode, fiber.CookieSameSiteStrictMode):
		sameSite = fasthttp.CookieSameSiteStrictMode
	case utils.EqualFold(mode, fiber.CookieSameSiteNoneMode):
		sameSite = fasthttp.CookieSameSiteNoneMode
	default:
		sameSite = fasthttp.CookieSameSiteLaxMode
	}
	fcookie.SetSameSite(sameSite)

	// The Secure attribute is required for SameSite=None.
	fcookie.SetSecure(sameSite == fasthttp.CookieSameSiteNoneMode || s.config.CookieSecure)
	fcookie.SetHTTPOnly(s.config.CookieHTTPOnly)
}
// decodeSessionData gob-decodes rawData into the session's key/value map.
//
// Parameters:
//   - rawData: The gob-encoded bytes read from storage.
//
// Returns:
//   - error: An error if the decoding fails.
//
// Usage:
//
//	err := s.decodeSessionData(rawData)
func (s *Session) decodeSessionData(rawData []byte) error {
	buf := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
	defer func() {
		// Empty the buffer before it goes back so pooled buffers start clean.
		buf.Reset()
		byteBufferPool.Put(buf)
	}()

	_, _ = buf.Write(rawData) // bytes.Buffer.Write always returns a nil error

	if err := gob.NewDecoder(buf).Decode(&s.data.Data); err != nil {
		return fmt.Errorf("failed to decode session data: %w", err)
	}
	return nil
}
// encodeSessionData gob-encodes the session's key/value map.
//
// It does not lock s.data itself; saveSession holds s.data's read lock around
// the call.
//
// Returns:
//   - []byte: A freshly allocated copy of the encoded bytes. It is safe to
//     retain after the pooled scratch buffer is reused.
//   - error: An error if encoding fails, e.g. for a key or value type that was
//     not registered with gob.
//
// Usage:
//
//	encoded, err := s.encodeSessionData()
func (s *Session) encodeSessionData() ([]byte, error) {
	byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
	// Deferred calls run LIFO: Reset first, then return the buffer to the pool.
	defer byteBufferPool.Put(byteBuffer)
	defer byteBuffer.Reset()

	encCache := gob.NewEncoder(byteBuffer)
	if err := encCache.Encode(&s.data.Data); err != nil {
		return nil, fmt.Errorf("failed to encode session data: %w", err)
	}

	// The buffer goes back to the pool on return, so hand out a private copy.
	encodedBytes := make([]byte, byteBuffer.Len())
	copy(encodedBytes, byteBuffer.Bytes())
	return encodedBytes, nil
}
// absExpiration returns the absolute expiration time stored in the session
// data, or the zero time if none is set (or the stored value is not a
// time.Time).
//
// Returns:
//   - time.Time: The session absolute expiration time. Zero time if not set.
//
// Usage:
//
//	expiration := s.absExpiration()
func (s *Session) absExpiration() time.Time {
	// A failed assertion leaves expiresAt at its zero value.
	expiresAt, _ := s.Get(absExpirationKey).(time.Time)
	return expiresAt
}
// isAbsExpired reports whether the session has passed its absolute expiration.
//
// A session without an absolute expiration (zero time) never expires by this
// check.
//
// Returns:
//   - bool: True if the session is expired; otherwise, false.
func (s *Session) isAbsExpired() bool {
	expiresAt := s.absExpiration()
	if expiresAt.IsZero() {
		return false
	}
	return time.Now().After(expiresAt)
}
// setAbsExpiration records absExpiration in the session data under the
// package-private absExpirationKey.
//
// Parameters:
//   - absExpiration: The absolute time after which the session counts as expired.
//
// Usage:
//
//	s.setAbsExpiration(time.Now().Add(time.Hour))
func (s *Session) setAbsExpiration(absExpiration time.Time) {
	key := absExpirationKey
	s.Set(key, absExpiration)
}
================================================
FILE: middleware/session/session_test.go
================================================
package session
import (
"errors"
"strings"
"sync"
"testing"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/gofiber/fiber/v3/internal/storage/memory"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// go test -run Test_Session
//
// Test_Session walks a session through its lifecycle on a standalone store:
// create and save, reload by cookie, set/get/delete values, then reload by the
// original ID.
func Test_Session(t *testing.T) {
	t.Parallel()

	// session store
	store := NewStore()
	// fiber instance
	app := fiber.New()
	// fiber context
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})

	// A request without a cookie gets a new session; persist it.
	sess, err := store.Get(ctx)
	require.NoError(t, err)
	require.True(t, sess.Fresh())
	token := sess.ID()
	require.NoError(t, sess.Save())

	sess.Release()
	app.ReleaseCtx(ctx)
	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})

	// set session using default cookie extractor
	ctx.Request().Header.SetCookie("session_id", token)

	// The saved session is found again, so it is no longer fresh.
	sess, err = store.Get(ctx)
	require.NoError(t, err)
	require.False(t, sess.Fresh())

	// The reloaded session holds no values yet.
	keys := sess.Keys()
	require.Equal(t, []any{}, keys)

	// get value
	name := sess.Get("name")
	require.Nil(t, name)

	// set value
	sess.Set("name", "john")

	// get value
	name = sess.Get("name")
	require.Equal(t, "john", name)

	keys = sess.Keys()
	require.Equal(t, []any{"name"}, keys)

	// delete key
	sess.Delete("name")

	// get value
	name = sess.Get("name")
	require.Nil(t, name)

	// get keys
	keys = sess.Keys()
	require.Equal(t, []any{}, keys)

	// Data operations must not change the session ID.
	id := sess.ID()
	require.Equal(t, token, id)

	// save the old session first
	err = sess.Save()
	require.NoError(t, err)

	// release the session
	sess.Release()
	// release the context
	app.ReleaseCtx(ctx)

	// A brand-new context carries no cookie, so a new session with a newly
	// generated 43-character ID must be created.
	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
	sess, err = store.Get(ctx)
	require.NoError(t, err)
	require.True(t, sess.Fresh())
	require.Len(t, sess.ID(), 43)
	sess.Release()

	// Presenting the original ID again must load the same, unexpired session.
	app.ReleaseCtx(ctx)
	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)

	// request the server with the old session
	ctx.Request().Header.SetCookie("session_id", id)
	sess, err = store.Get(ctx)
	require.NoError(t, err)
	defer sess.Release()
	require.False(t, sess.Fresh())
	require.Equal(t, id, sess.ID())
}
// go test -run Test_Session_Types
//
// Test_Session_Types stores one value of every builtin scalar kind, plus a
// registered struct, and checks that each round-trips through storage with its
// exact dynamic type.
func Test_Session_Types(t *testing.T) {
	t.Parallel()

	// session store
	store := NewStore()
	// fiber instance
	app := fiber.New()
	// fiber context
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})

	// Present a cookie whose ID ("123") does not exist in storage.
	ctx.Request().Header.SetCookie("session_id", "123")

	// get session
	sess, err := store.Get(ctx)
	require.NoError(t, err)
	require.True(t, sess.Fresh())

	// The unknown ID is not reused; remember the ID the fresh session received.
	newSessionIDString := sess.ID()

	type User struct {
		Name string
	}
	store.RegisterType(User{})
	vuser := User{
		Name: "John",
	}

	// set value
	var (
		vbool                  = true
		vstring                = "str"
		vint                   = 13
		vint8       int8       = 13
		vint16      int16      = 13
		vint32      int32      = 13
		vint64      int64      = 13
		vuint       uint       = 13
		vuint8      uint8      = 13
		vuint16     uint16     = 13
		vuint32     uint32     = 13
		vuint64     uint64     = 13
		vuintptr    uintptr    = 13
		vbyte       byte       = 'k'
		vrune                  = 'k'
		vfloat32    float32    = 13
		vfloat64    float64    = 13
		vcomplex64  complex64  = 13
		vcomplex128 complex128 = 13
	)
	sess.Set("vuser", vuser)
	sess.Set("vbool", vbool)
	sess.Set("vstring", vstring)
	sess.Set("vint", vint)
	sess.Set("vint8", vint8)
	sess.Set("vint16", vint16)
	sess.Set("vint32", vint32)
	sess.Set("vint64", vint64)
	sess.Set("vuint", vuint)
	sess.Set("vuint8", vuint8)
	sess.Set("vuint16", vuint16)
	sess.Set("vuint32", vuint32)
	sess.Set("vuint64", vuint64)
	sess.Set("vuintptr", vuintptr)
	sess.Set("vbyte", vbyte)
	sess.Set("vrune", vrune)
	sess.Set("vfloat32", vfloat32)
	sess.Set("vfloat64", vfloat64)
	sess.Set("vcomplex64", vcomplex64)
	sess.Set("vcomplex128", vcomplex128)

	// save session
	err = sess.Save()
	require.NoError(t, err)

	sess.Release()
	app.ReleaseCtx(ctx)
	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
	ctx.Request().Header.SetCookie("session_id", newSessionIDString)

	// Reload the saved session; it must not be fresh.
	sess, err = store.Get(ctx)
	require.NoError(t, err)
	require.False(t, sess.Fresh())

	// Each value must come back with its original dynamic type.
	vuserResult, ok := sess.Get("vuser").(User)
	require.True(t, ok)
	require.Equal(t, vuser, vuserResult)
	vboolResult, ok := sess.Get("vbool").(bool)
	require.True(t, ok)
	require.Equal(t, vbool, vboolResult)
	vstringResult, ok := sess.Get("vstring").(string)
	require.True(t, ok)
	require.Equal(t, vstring, vstringResult)
	vintResult, ok := sess.Get("vint").(int)
	require.True(t, ok)
	require.Equal(t, vint, vintResult)
	vint8Result, ok := sess.Get("vint8").(int8)
	require.True(t, ok)
	require.Equal(t, vint8, vint8Result)
	vint16Result, ok := sess.Get("vint16").(int16)
	require.True(t, ok)
	require.Equal(t, vint16, vint16Result)
	vint32Result, ok := sess.Get("vint32").(int32)
	require.True(t, ok)
	require.Equal(t, vint32, vint32Result)
	vint64Result, ok := sess.Get("vint64").(int64)
	require.True(t, ok)
	require.Equal(t, vint64, vint64Result)
	vuintResult, ok := sess.Get("vuint").(uint)
	require.True(t, ok)
	require.Equal(t, vuint, vuintResult)
	vuint8Result, ok := sess.Get("vuint8").(uint8)
	require.True(t, ok)
	require.Equal(t, vuint8, vuint8Result)
	vuint16Result, ok := sess.Get("vuint16").(uint16)
	require.True(t, ok)
	require.Equal(t, vuint16, vuint16Result)
	vuint32Result, ok := sess.Get("vuint32").(uint32)
	require.True(t, ok)
	require.Equal(t, vuint32, vuint32Result)
	vuint64Result, ok := sess.Get("vuint64").(uint64)
	require.True(t, ok)
	require.Equal(t, vuint64, vuint64Result)
	vuintptrResult, ok := sess.Get("vuintptr").(uintptr)
	require.True(t, ok)
	require.Equal(t, vuintptr, vuintptrResult)
	vbyteResult, ok := sess.Get("vbyte").(byte)
	require.True(t, ok)
	require.Equal(t, vbyte, vbyteResult)
	vruneResult, ok := sess.Get("vrune").(rune)
	require.True(t, ok)
	require.Equal(t, vrune, vruneResult)
	vfloat32Result, ok := sess.Get("vfloat32").(float32)
	require.True(t, ok)
	require.InEpsilon(t, vfloat32, vfloat32Result, 0.001)
	vfloat64Result, ok := sess.Get("vfloat64").(float64)
	require.True(t, ok)
	require.InEpsilon(t, vfloat64, vfloat64Result, 0.001)
	vcomplex64Result, ok := sess.Get("vcomplex64").(complex64)
	require.True(t, ok)
	require.Equal(t, vcomplex64, vcomplex64Result)
	vcomplex128Result, ok := sess.Get("vcomplex128").(complex128)
	require.True(t, ok)
	require.Equal(t, vcomplex128, vcomplex128Result)

	sess.Release()
	app.ReleaseCtx(ctx)
}
// go test -run Test_Session_Store_Reset
//
// Test_Session_Store_Reset checks that Store.Reset wipes persisted sessions:
// presenting a previously saved ID afterwards yields a new, empty session.
func Test_Session_Store_Reset(t *testing.T) {
	t.Parallel()

	store := NewStore()
	app := fiber.New()

	// Create a session holding one value, point the request cookie at it, save it.
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	sess, err := store.Get(ctx)
	require.NoError(t, err)
	require.True(t, sess.Fresh())
	sess.Set("hello", "world")
	ctx.Request().Header.SetCookie("session_id", sess.ID())
	require.NoError(t, sess.Save())

	// Wipe the whole store.
	require.NoError(t, store.Reset(ctx))

	staleID := sess.ID()
	sess.Release()
	app.ReleaseCtx(ctx)

	// The stale ID must no longer resolve to the saved data.
	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)
	ctx.Request().Header.SetCookie("session_id", staleID)

	sess, err = store.Get(ctx)
	defer sess.Release()
	require.NoError(t, err)
	require.True(t, sess.Fresh())
	require.Nil(t, sess.Get("hello"))
}
// Test_Session_KeyTypes checks two things. First, Save fails while an
// unregistered key or value type is present and succeeds once it is removed.
// Second, every supported key kind round-trips through storage.
func Test_Session_KeyTypes(t *testing.T) {
	// Not run in parallel: it registers types via store.RegisterType(), which
	// populates the process-global gob registry shared with other tests.

	store := NewStore()
	app := fiber.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})

	sess, err := store.Get(ctx)
	require.NoError(t, err)
	require.True(t, sess.Fresh())

	type Person struct {
		Name string
	}
	type unexportedKey int

	// Non-builtin types must be registered before they can be encoded.
	store.RegisterType(Person{})
	store.RegisterType(unexportedKey(0))

	type unregisteredKeyType int
	type unregisteredValueType int

	// An unregistered key type makes Save fail until the entry is removed.
	var unregisteredKey unregisteredKeyType
	sess.Set(unregisteredKey, "test")
	require.Error(t, sess.Save())
	sess.Delete(unregisteredKey)
	require.NoError(t, sess.Save())

	// Same for an unregistered value type.
	var unregisteredValue unregisteredValueType
	sess.Set("abc", unregisteredValue)
	require.Error(t, sess.Save())
	sess.Delete("abc")
	require.NoError(t, sess.Save())

	require.NoError(t, sess.Reset())
	sess.Release()

	// Start over with a fresh session on the same context.
	sess, err = store.Get(ctx)
	require.NoError(t, err)
	require.True(t, sess.Fresh())

	// Every kind is used both as a key and as the value stored under it.
	pairs := []struct {
		key, value any
	}{
		{true, true},
		{"str", "str"},
		{13, 13},
		{int8(13), int8(13)},
		{int16(13), int16(13)},
		{int32(13), int32(13)},
		{int64(13), int64(13)},
		{uint(13), uint(13)},
		{uint8(13), uint8(13)},
		{uint16(13), uint16(13)},
		{uint32(13), uint32(13)},
		{uint64(13), uint64(13)},
		{uintptr(13), uintptr(13)},
		{byte('k'), byte('k')},
		{'k', 'k'},
		{float32(13), float32(13)},
		{float64(13), float64(13)},
		{complex64(13), complex64(13)},
		{complex128(13), complex128(13)},
		{Person{Name: "John"}, Person{Name: "John"}},
		{unexportedKey(13), unexportedKey(13)},
	}
	for _, p := range pairs {
		sess.Set(p.key, p.value)
	}

	id := sess.ID()
	ctx.Request().Header.SetCookie("session_id", id)
	require.NoError(t, sess.Save())
	sess.Release()
	app.ReleaseCtx(ctx)

	// Reload through a new context and check every pair round-tripped.
	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)
	ctx.Request().Header.SetCookie("session_id", id)

	sess, err = store.Get(ctx)
	require.NoError(t, err)
	defer sess.Release()
	require.False(t, sess.Fresh())

	for _, p := range pairs {
		require.Equal(t, p.value, sess.Get(p.key))
	}
}
// go test -run Test_Session_Save
//
// Test_Session_Save checks that Save succeeds and writes the session ID to the
// response carrier selected by the store's extractor (cookie by default, or a
// header when a header extractor is configured).
func Test_Session_Save(t *testing.T) {
	t.Parallel()

	t.Run("save to cookie", func(t *testing.T) {
		t.Parallel()
		store := NewStore()
		app := fiber.New()

		// Return the context to the pool when the subtest ends.
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)

		sess, err := store.Get(ctx)
		require.NoError(t, err)
		sess.Set("name", "john")
		require.NoError(t, sess.Save())
		id := sess.ID()
		sess.Release()

		// With the default extractor, Save emits the session cookie.
		cookie := ctx.Response().Header.PeekCookie("session_id")
		require.NotNil(t, cookie)
		require.Contains(t, string(cookie), id)
	})

	t.Run("save to header", func(t *testing.T) {
		t.Parallel()
		store := NewStore(Config{
			Extractor: extractors.FromHeader("session_id"),
		})
		app := fiber.New()

		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)

		sess, err := store.Get(ctx)
		require.NoError(t, err)
		sess.Set("name", "john")
		require.NoError(t, sess.Save())

		// With a header extractor, Save echoes the ID in the response header.
		require.Equal(t, sess.ID(), string(ctx.Response().Header.Peek("session_id")))
		sess.Release()
	})
}
// Test_Session_ChainedExtractors checks which response carriers Save writes the
// session ID to when the store uses a chain of extractors: cookie and header
// extractors are echoed back in the response, query and form extractors are not.
func Test_Session_ChainedExtractors(t *testing.T) {
	t.Parallel()

	t.Run("cookie and header chain", func(t *testing.T) {
		t.Parallel()
		store := NewStore(Config{
			Extractor: extractors.Chain(extractors.FromCookie("session_id"), extractors.FromHeader("x-session-id")),
		})
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)

		sess, err := store.Get(ctx)
		require.NoError(t, err)
		sess.Set("name", "john")
		require.NoError(t, sess.Save())

		// Both carriers hold the ID.
		setCookie := ctx.Response().Header.PeekCookie("session_id")
		require.NotNil(t, setCookie)
		require.Contains(t, string(setCookie), sess.ID())
		require.Equal(t, sess.ID(), string(ctx.Response().Header.Peek("x-session-id")))
		sess.Release()
	})

	t.Run("header and cookie chain", func(t *testing.T) {
		t.Parallel()
		// Same extractors as above, in the opposite order.
		store := NewStore(Config{
			Extractor: extractors.Chain(extractors.FromHeader("x-session-id"), extractors.FromCookie("session_id")),
		})
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)

		sess, err := store.Get(ctx)
		require.NoError(t, err)
		sess.Set("name", "john")
		require.NoError(t, sess.Save())

		require.Equal(t, sess.ID(), string(ctx.Response().Header.Peek("x-session-id")))
		setCookie := ctx.Response().Header.PeekCookie("session_id")
		require.NotNil(t, setCookie)
		require.Contains(t, string(setCookie), sess.ID())
		sess.Release()
	})

	t.Run("only SourceOther extractors - no response setting", func(t *testing.T) {
		t.Parallel()
		store := NewStore(Config{
			Extractor: extractors.Chain(extractors.FromQuery("session_id"), extractors.FromForm("session_id")),
		})
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)

		sess, err := store.Get(ctx)
		require.NoError(t, err)
		sess.Set("name", "john")
		require.NoError(t, sess.Save())

		// Query/form sources have no response counterpart: nothing is written.
		require.Nil(t, ctx.Response().Header.PeekCookie("session_id"))
		require.Empty(t, string(ctx.Response().Header.Peek("session_id")))
		sess.Release()
	})

	t.Run("mixed chain with SourceOther", func(t *testing.T) {
		t.Parallel()
		store := NewStore(Config{
			Extractor: extractors.Chain(extractors.FromCookie("session_id"), extractors.FromQuery("session_id"), extractors.FromHeader("x-session-id")),
		})
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)

		sess, err := store.Get(ctx)
		require.NoError(t, err)
		sess.Set("name", "john")
		require.NoError(t, sess.Save())

		// Cookie and header are written; the query extractor is skipped.
		setCookie := ctx.Response().Header.PeekCookie("session_id")
		require.NotNil(t, setCookie)
		require.Contains(t, string(setCookie), sess.ID())
		require.Equal(t, sess.ID(), string(ctx.Response().Header.Peek("x-session-id")))
		sess.Release()
	})
}
// Test_Session_Save_IdleTimeout checks that a session saved with a custom idle
// timeout is still loadable before the timeout and is replaced by a new
// session once the timeout has elapsed.
func Test_Session_Save_IdleTimeout(t *testing.T) {
	t.Parallel()

	t.Run("save to cookie", func(t *testing.T) {
		t.Parallel()
		const sessionDuration = 5 * time.Second

		store := NewStore()
		app := fiber.New()

		// Request 1: create a session with a 5s idle timeout.
		// This context is released explicitly below and is deliberately NOT
		// also deferred: releasing it twice would put the same object into the
		// pool twice, letting two parallel tests share one context.
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		sess, err := store.Get(ctx)
		require.NoError(t, err)
		sess.Set("name", "john")
		token := sess.ID()
		sess.SetIdleTimeout(sessionDuration)
		require.NoError(t, sess.Save())
		sess.Release()
		app.ReleaseCtx(ctx)

		// Request 2: before expiration the stored session is returned.
		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
		ctx.Request().Header.SetCookie("session_id", token)
		sess, err = store.Get(ctx)
		require.NoError(t, err)
		require.Equal(t, token, sess.ID(), "session ID should match before expiration")
		require.Equal(t, "john", sess.Get("name"), "session should contain the saved value before expiration")

		// Wait past the idle timeout, with a small buffer so expiration is processed.
		time.Sleep(sessionDuration + (100 * time.Millisecond))
		sess.Release()
		app.ReleaseCtx(ctx)

		// Request 3: the expired session is replaced by a new one.
		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)
		ctx.Request().Header.SetCookie("session_id", token)
		sess, err = store.Get(ctx)
		// Check the error before deferring Release so a failed Get never
		// defers a call on a possibly nil session.
		require.NoError(t, err)
		defer sess.Release()
		require.Nil(t, sess.Get("name"))
		require.NotEqual(t, sess.ID(), token)
	})
}
// Test_Session_Save_AbsoluteTimeout checks that a session expires after the
// absolute timeout even when the idle timeout is longer. It also checks that a
// replacement session carries an absolute expiration timestamp and becomes
// unreachable by ID once it, too, has expired.
func Test_Session_Save_AbsoluteTimeout(t *testing.T) {
	t.Parallel()

	t.Run("save to cookie", func(t *testing.T) {
		t.Parallel()
		const absoluteTimeout = 2 * time.Second // extra headroom to avoid flakiness under -race

		store := NewStore(Config{
			IdleTimeout:     absoluteTimeout,
			AbsoluteTimeout: absoluteTimeout,
		})
		// Raise the idle timeout afterwards so only the absolute timeout can
		// expire the session within this test.
		store.IdleTimeout = 10 * time.Second

		app := fiber.New()

		// Request 1: create and save a session. Released explicitly (and not
		// also deferred) so the context is returned to the pool exactly once.
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		sess, err := store.Get(ctx)
		require.NoError(t, err)
		sess.Set("name", "john")
		token := sess.ID()
		require.NoError(t, sess.Save())
		sess.Release()
		app.ReleaseCtx(ctx)

		// Request 2: the session is still valid.
		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
		ctx.Request().Header.SetCookie("session_id", token)
		sess, err = store.Get(ctx)
		require.NoError(t, err)
		require.Equal(t, "john", sess.Get("name"))

		// Wait past the absolute timeout.
		time.Sleep(absoluteTimeout + (200 * time.Millisecond))
		sess.Release()
		app.ReleaseCtx(ctx)

		// Request 3: the old session has expired, so a fresh one with a new ID
		// (and its own absolute expiration timestamp) is created.
		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
		ctx.Request().Header.SetCookie("session_id", token)
		sess, err = store.Get(ctx)
		require.NoError(t, err)
		require.Nil(t, sess.Get("name"))
		require.NotEqual(t, sess.ID(), token)
		require.True(t, sess.Fresh())
		require.IsType(t, time.Time{}, sess.Get(absExpirationKey))

		token = sess.ID()
		sess.Set("name", "john")
		require.NoError(t, sess.Save())
		sess.Release()
		app.ReleaseCtx(ctx)

		// Wait until the replacement session has expired as well.
		time.Sleep(absoluteTimeout + (200 * time.Millisecond))

		// Look the expired session up by ID on a freshly acquired context. The
		// previous context has already been returned to the pool and must not
		// be used again.
		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)
		sess, err = store.GetByID(ctx, token)
		require.Error(t, err)
		require.ErrorIs(t, err, ErrSessionIDNotFoundInStore)
		require.Nil(t, sess)
	})
}
// go test -run Test_Session_Destroy
//
// Test_Session_Destroy checks that Destroy clears session data. With a header
// extractor, it also checks that Destroy leaves no session ID in the response
// header.
func Test_Session_Destroy(t *testing.T) {
	t.Parallel()

	t.Run("destroy from cookie", func(t *testing.T) {
		t.Parallel()
		store := NewStore()
		app := fiber.New()
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)

		sess, err := store.Get(ctx)
		// Check the error before deferring Release so a failed Get never
		// defers a call on a possibly nil session.
		require.NoError(t, err)
		defer sess.Release()

		sess.Set("name", "fenny")
		require.NoError(t, sess.Destroy())
		require.Nil(t, sess.Get("name"))
	})

	t.Run("destroy from header", func(t *testing.T) {
		t.Parallel()
		store := NewStore(Config{
			Extractor: extractors.FromHeader("session_id"),
		})
		app := fiber.New()

		// Request 1: create and persist a session. This context is released
		// explicitly and deliberately not deferred; releasing it twice would
		// return the same object to the pool twice.
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		sess, err := store.Get(ctx)
		require.NoError(t, err)
		sess.Set("name", "fenny")
		id := sess.ID()
		require.NoError(t, sess.Save())
		sess.Release()
		app.ReleaseCtx(ctx)

		// Request 2: load the session via the header and destroy it.
		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)
		ctx.Request().Header.Set("session_id", id)
		sess, err = store.Get(ctx)
		require.NoError(t, err)
		defer sess.Release()

		require.NoError(t, sess.Destroy())
		require.Empty(t, string(ctx.Response().Header.Peek("session_id")))
	})
}
// go test -run Test_Session_Custom_Config
//
// Test_Session_Custom_Config checks that explicit configuration values are kept
// and that a zero IdleTimeout falls back to ConfigDefault.IdleTimeout.
func Test_Session_Custom_Config(t *testing.T) {
	t.Parallel()

	custom := NewStore(Config{
		IdleTimeout:  time.Hour,
		KeyGenerator: func() string { return "very random" },
	})
	require.Equal(t, time.Hour, custom.IdleTimeout)
	require.Equal(t, "very random", custom.KeyGenerator())

	zeroed := NewStore(Config{IdleTimeout: 0})
	require.Equal(t, ConfigDefault.IdleTimeout, zeroed.IdleTimeout)
}
// go test -run Test_Session_Cookie
//
// Test_Session_Cookie checks the exact shape of the cookie written by Save with
// the default configuration, even when the session holds no data.
func Test_Session_Cookie(t *testing.T) {
	t.Parallel()

	store := NewStore()
	app := fiber.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)

	sess, err := store.Get(ctx)
	require.NoError(t, err)
	require.NoError(t, sess.Save())
	sess.Release()

	setCookie := ctx.Response().Header.PeekCookie("session_id")
	require.NotNil(t, setCookie)
	require.Regexp(t, `^session_id=[A-Za-z0-9\-_]{43}; max-age=\d+; path=/; SameSite=Lax$`, string(setCookie))
}
// go test -run Test_Session_Cookie_SameSite
//
// Test_Session_Cookie_SameSite checks how CookieSameSite is normalized and
// whether it forces the Secure attribute: None (any case) forces Secure, while
// Lax and Strict keep whatever CookieSecure was set to. Unknown values fall
// back to Lax.
func Test_Session_Cookie_SameSite(t *testing.T) {
	t.Parallel()

	tests := []struct {
		expectedInHeader string
		name             string
		sameSite         string
		initialSecure    bool
	}{
		{
			name:             "Lax should not force secure",
			sameSite:         "Lax",
			initialSecure:    false,
			expectedInHeader: "SameSite=Lax",
		},
		{
			name:             "Lax with secure should stay secure",
			sameSite:         "Lax",
			initialSecure:    true,
			expectedInHeader: "SameSite=Lax; secure",
		},
		{
			name:             "Strict should not force secure",
			sameSite:         "Strict",
			initialSecure:    false,
			expectedInHeader: "SameSite=Strict",
		},
		{
			name:             "Strict with secure should stay secure",
			sameSite:         "Strict",
			initialSecure:    true,
			expectedInHeader: "SameSite=Strict; secure",
		},
		{
			name:             "None should force secure",
			sameSite:         "None",
			initialSecure:    false,
			expectedInHeader: "SameSite=None; secure",
		},
		{
			name:             "None with secure should stay secure",
			sameSite:         "None",
			initialSecure:    true,
			expectedInHeader: "SameSite=None; secure",
		},
		{
			name:             "Case-insensitive none should force secure",
			sameSite:         "none",
			initialSecure:    false,
			expectedInHeader: "SameSite=None; secure",
		},
		{
			name:             "Case-insensitive strict should not force secure",
			sameSite:         "strict",
			initialSecure:    false,
			expectedInHeader: "SameSite=Strict",
		},
		{
			name:             "Default should be Lax",
			sameSite:         "invalid",
			initialSecure:    false,
			expectedInHeader: "SameSite=Lax",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			store := NewStore(Config{
				CookieSameSite: tc.sameSite,
				CookieSecure:   tc.initialSecure,
			})
			app := fiber.New()
			ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
			defer app.ReleaseCtx(ctx)

			sess, err := store.Get(ctx)
			require.NoError(t, err)
			defer sess.Release()

			// Save triggers the Set-Cookie header.
			require.NoError(t, sess.Save())
			cookie := string(ctx.Response().Header.PeekCookie("session_id"))

			// Attribute order in the cookie string is not guaranteed, so check
			// each expected attribute independently.
			for part := range strings.SplitSeq(tc.expectedInHeader, "; ") {
				require.Contains(t, cookie, part)
			}

			// Secure must be absent unless it was requested or forced by
			// SameSite=None. SameSite is matched case-insensitively (see the
			// "Case-insensitive" cases), so this guard must be too.
			if !tc.initialSecure && !strings.EqualFold(tc.sameSite, "None") {
				require.NotContains(t, cookie, "secure")
			}
		})
	}
}
// go test -run Test_Session_Cookie_In_Middleware_Chain
// Regression: https://github.com/gofiber/fiber/pull/1191
//
// A second store.Get on the same request context (as happens when several
// handlers in a middleware chain access the session) must return the session
// saved by the first caller: same ID, not fresh, and still holding its data.
func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
	store := NewStore()
	app := fiber.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)

	// First caller: create the session, store a value, and save it.
	sess, err := store.Get(ctx)
	require.NoError(t, err)
	sess.Set("id", "1")
	require.True(t, sess.Fresh())
	id := sess.ID()
	require.NoError(t, sess.Save())
	sess.Release()

	// Second caller: fetch the session again on the same context.
	sess, err = store.Get(ctx)
	require.NoError(t, err)
	defer sess.Release()
	sess.Set("name", "john")
	require.False(t, sess.Fresh()) // Session should not be fresh - it reuses the same ID from context locals
	require.Equal(t, id, sess.ID()) // session id should be the same
	require.Equal(t, "1", sess.Get("id"))
	require.Equal(t, "john", sess.Get("name"))
}
// go test -run Test_Session_Deletes_Single_Key
// Regression: https://github.com/gofiber/fiber/issues/1365
//
// Deleting the only key of a stored session and saving must persist the
// deletion without discarding the session itself.
func Test_Session_Deletes_Single_Key(t *testing.T) {
	t.Parallel()
	store := NewStore()
	app := fiber.New()

	// Request 1: store a single key.
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	sess, err := store.Get(ctx)
	require.NoError(t, err)
	id := sess.ID()
	sess.Set("id", "1")
	require.NoError(t, sess.Save())
	sess.Release()
	app.ReleaseCtx(ctx)

	// Request 2: delete that key and save.
	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
	ctx.Request().Header.SetCookie("session_id", id)
	sess, err = store.Get(ctx)
	require.NoError(t, err)
	sess.Delete("id")
	require.NoError(t, sess.Save())
	sess.Release()
	app.ReleaseCtx(ctx)

	// Request 3: the session still exists, but the key is gone.
	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
	// Deferred before the session so that, LIFO, the session is released
	// first and its context last.
	defer app.ReleaseCtx(ctx)
	ctx.Request().Header.SetCookie("session_id", id)
	sess, err = store.Get(ctx)
	// Check the error before deferring Release so a failed Get never defers a
	// call on a possibly nil session.
	require.NoError(t, err)
	defer sess.Release()
	require.False(t, sess.Fresh())
	require.Nil(t, sess.Get("id"))
}
// go test -run Test_Session_Reset
//
// Test_Session_Reset checks that Reset gives a loaded session a new ID, marks
// it fresh, and drops all previously stored data. It also checks that the reset
// session can be written and saved again.
func Test_Session_Reset(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	store := NewStore()

	t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) {
		t.Parallel()

		// Request 1: persist a session holding two values.
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		first, err := store.Get(ctx)
		require.NoError(t, err)
		originalID := first.ID()
		first.Set("name", "fenny")
		first.Set("email", "fenny@example.com")
		require.NoError(t, first.Save())
		first.Release()
		app.ReleaseCtx(ctx)

		// Request 2: the stored session is loaded (not fresh), then reset.
		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
		ctx.Request().Header.SetCookie("session_id", originalID)
		loaded, err := store.Get(ctx)
		require.NoError(t, err)
		require.False(t, loaded.Fresh())

		require.NoError(t, loaded.Reset())
		require.NotEqual(t, originalID, loaded.ID())
		require.True(t, loaded.Fresh())
		require.Equal(t, []any{}, loaded.Keys())

		// The reset session accepts new data; old keys stay gone.
		loaded.Set("name", "john")
		require.Equal(t, "john", loaded.Get("name"))
		require.Nil(t, loaded.Get("email"))

		require.NoError(t, loaded.Save())
		loaded.Release()

		// No "session_id" header on either the request or the response.
		require.Empty(t, string(ctx.Response().Header.Peek("session_id")))
		require.Empty(t, string(ctx.Request().Header.Peek("session_id")))
		app.ReleaseCtx(ctx)
	})
}
// go test -run Test_Session_Regenerate
// Regression: https://github.com/gofiber/fiber/issues/1395
//
// Regenerate on a loaded (non-fresh) session must assign a new ID and mark the
// session fresh.
func Test_Session_Regenerate(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	t.Run("set fresh to be true when regenerating a session", func(t *testing.T) {
		t.Parallel()
		store := NewStore()

		// Request 1: persist a new session. This context is released explicitly
		// and deliberately not deferred, so it goes back to the pool exactly once.
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		freshSession, err := store.Get(ctx)
		require.NoError(t, err)
		originalID := freshSession.ID()
		require.NoError(t, freshSession.Save())
		freshSession.Release()
		app.ReleaseCtx(ctx)

		// Request 2: load the session via its cookie and regenerate it.
		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
		// Deferred before the session so that, LIFO, the session is released
		// first and its context last.
		defer app.ReleaseCtx(ctx)
		ctx.Request().Header.SetCookie("session_id", originalID)
		acquiredSession, err := store.Get(ctx)
		require.NoError(t, err)
		defer acquiredSession.Release()
		require.False(t, acquiredSession.Fresh())

		require.NoError(t, acquiredSession.Regenerate())
		require.NotEqual(t, originalID, acquiredSession.ID())
		require.True(t, acquiredSession.Fresh())
	})
}
// go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4
func Benchmark_Session(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), NewStore()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie("session_id", "12356789")
b.ReportAllocs()
for b.Loop() {
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
sess.Release()
}
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := NewStore(Config{
Storage: memory.New(),
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie("session_id", "12356789")
b.ReportAllocs()
for b.Loop() {
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
sess.Release()
}
})
}
// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4
func Benchmark_Session_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), NewStore()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie("session_id", "12356789")
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
sess.Release()
app.ReleaseCtx(c)
}
})
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := NewStore(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie("session_id", "12356789")
sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
sess.Set("john", "doe")
_ = sess.Save() //nolint:errcheck // We're inside a benchmark
sess.Release()
app.ReleaseCtx(c)
}
})
})
}
// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4
func Benchmark_Session_Asserted(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), NewStore()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie("session_id", "12356789")
b.ReportAllocs()
for b.Loop() {
sess, err := store.Get(c)
require.NoError(b, err)
sess.Set("john", "doe")
err = sess.Save()
require.NoError(b, err)
sess.Release()
}
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := NewStore(Config{
Storage: memory.New(),
})
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.SetCookie("session_id", "12356789")
b.ReportAllocs()
for b.Loop() {
sess, err := store.Get(c)
require.NoError(b, err)
sess.Set("john", "doe")
err = sess.Save()
require.NoError(b, err)
sess.Release()
}
})
}
// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4
func Benchmark_Session_Asserted_Parallel(b *testing.B) {
b.Run("default", func(b *testing.B) {
app, store := fiber.New(), NewStore()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie("session_id", "12356789")
sess, err := store.Get(c)
require.NoError(b, err)
sess.Set("john", "doe")
require.NoError(b, sess.Save())
sess.Release()
app.ReleaseCtx(c)
}
})
})
b.Run("storage", func(b *testing.B) {
app := fiber.New()
store := NewStore(Config{
Storage: memory.New(),
})
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Header.SetCookie("session_id", "12356789")
sess, err := store.Get(c)
require.NoError(b, err)
sess.Set("john", "doe")
require.NoError(b, sess.Save())
sess.Release()
app.ReleaseCtx(c)
}
})
})
}
// go test -v -race -run Test_Session_Concurrency ./...
//
// Test_Session_Concurrency runs the full session lifecycle concurrently:
// create, save, reload by cookie, verify, delete a key, and destroy. Each
// goroutine reports at most one error, and every session and context it
// acquires is released, including on error paths.
func Test_Session_Concurrency(t *testing.T) {
	app := fiber.New()
	store := NewStore()

	const numGoroutines = 10 // Number of concurrent goroutines to test

	// One slot per goroutine: each sends at most one error, so sends never block.
	errChan := make(chan error, numGoroutines)
	var wg sync.WaitGroup

	// createSession handles the first request. It creates a fresh session,
	// stores a value, saves it, and returns the session ID.
	createSession := func() (string, error) {
		localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(localCtx)

		sess, err := store.getSession(localCtx)
		if err != nil {
			return "", err
		}
		defer sess.Release()

		sess.Set("name", "john")
		id := sess.ID()
		if !sess.Fresh() {
			return "", errors.New("session should be fresh")
		}
		if err := sess.Save(); err != nil {
			return "", err
		}
		return id, nil
	}

	// verifyAndDestroy handles the second request. It reloads the session by
	// cookie, checks its contents, deletes the key, and destroys the session.
	verifyAndDestroy := func(id string) error {
		localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(localCtx)
		localCtx.Request().Header.SetCookie("session_id", id)

		sess, err := store.Get(localCtx)
		if err != nil {
			return err
		}
		defer sess.Release()

		if sess.Get("name") != "john" {
			return errors.New("name should be john")
		}
		if sess.ID() != id {
			return errors.New("id should be the same")
		}
		if sess.Fresh() {
			return errors.New("session should not be fresh")
		}
		sess.Delete("name")
		if sess.Get("name") != nil {
			return errors.New("name should be nil")
		}
		return sess.Destroy()
	}

	for range numGoroutines {
		wg.Go(func() {
			id, err := createSession()
			if err == nil {
				err = verifyAndDestroy(id)
			}
			if err != nil {
				errChan <- err
			}
		})
	}

	wg.Wait()      // Wait for all goroutines to finish
	close(errChan) // No more errors will be sent

	for err := range errChan {
		require.NoError(t, err)
	}
}
// Test_Session_StoreGetDecodeSessionDataError checks that both Get and GetByID
// return a decode error when the storage holds bytes that are not a valid
// encoded session.
func Test_Session_StoreGetDecodeSessionDataError(t *testing.T) {
	store := NewStore()
	app := fiber.New()

	// Seed the storage with undecodable data under a random session ID.
	sessionID := uuid.New().String()
	require.NoError(t, store.Storage.Set(sessionID, []byte("invalid data"), 0), "Failed to set invalid session data")

	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)
	c.Request().Header.SetCookie("session_id", sessionID)

	// Loading the session referenced by the cookie must fail to decode.
	_, err := store.Get(c)
	require.Error(t, err, "Expected error due to invalid session data, but got nil")
	require.ErrorContains(t, err, "failed to decode session data", "Unexpected error")

	// Loading it directly by ID must fail the same way.
	_, err = store.GetByID(c, sessionID)
	require.Error(t, err, "Expected error due to invalid session data, but got nil")
	require.ErrorContains(t, err, "failed to decode session data", "Unexpected error")
}
// go test -run Test_Session_Fresh_Flag_Bug
//
// Regression test for the fresh flag. Calling getSession() more than once in
// the same request used to report an existing session as fresh once its ID
// had been cached in the context locals.
func Test_Session_Fresh_Flag_Bug(t *testing.T) {
	t.Parallel()
	store := NewStore()
	app := fiber.New()

	// 1. No cookie: a brand-new session is created and is fresh.
	createCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
	created, err := store.Get(createCtx)
	require.NoError(t, err)
	require.True(t, created.Fresh(), "First session should be fresh (no cookie provided)")
	sessionID := created.ID()
	require.NoError(t, created.Save())
	created.Release()
	app.ReleaseCtx(createCtx)

	// 2. Cookie pointing at a stored session: loaded and not fresh.
	reuseCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
	reuseCtx.Request().Header.SetCookie("session_id", sessionID)
	loaded, err := store.Get(reuseCtx)
	require.NoError(t, err)
	require.False(t, loaded.Fresh(), "Existing session should not be fresh")
	require.Equal(t, sessionID, loaded.ID())

	// 3. A second getSession() in the same request, as when the CSRF
	//    middleware also calls into the store. The ID now comes from the
	//    context locals, and the session must still not be fresh.
	again, err := store.getSession(reuseCtx)
	require.NoError(t, err)
	require.False(t, again.Fresh(), "Session should still not be fresh on second getSession() call in same request")
	require.Equal(t, sessionID, again.ID())

	loaded.Release()
	again.Release()
	app.ReleaseCtx(reuseCtx)

	// 4. Cookie for an unknown/expired ID: a new ID is generated and the
	//    session is fresh.
	const unknownID = "expired-or-nonexistent-id"
	staleCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
	staleCtx.Request().Header.SetCookie("session_id", unknownID)
	replaced, err := store.Get(staleCtx)
	require.NoError(t, err)
	require.True(t, replaced.Fresh(), "New session (after expired/missing data) should be fresh")
	require.NotEqual(t, unknownID, replaced.ID(), "Should have generated a new session ID")
	replaced.Release()
	app.ReleaseCtx(staleCtx)
}
// go test -run Test_Session_CSRF_Scenario
//
// Simulates the reported CSRF + session interaction. While the session is
// alive, a follow-up POST keeps the same session ID even if the request later
// fails CSRF validation. Only once the stored data has expired is a new
// session (and thus a new session_id cookie) issued.
func Test_Session_CSRF_Scenario(t *testing.T) {
	t.Parallel()

	const idleTimeout = 2 * time.Second // long enough for the session to survive the first POST
	store := NewStore(Config{
		IdleTimeout: idleTimeout,
	})
	app := fiber.New()

	// GET: create the session and store a CSRF token in it.
	getCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
	created, err := store.Get(getCtx)
	require.NoError(t, err)
	require.True(t, created.Fresh())
	sessionID := created.ID()
	created.Set("csrf_token", "token-123")
	require.NoError(t, created.Save())
	created.Release()
	app.ReleaseCtx(getCtx)

	// Give the save a moment to settle.
	time.Sleep(10 * time.Millisecond)

	// POST before expiration: the same session is loaded with its data.
	postCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
	postCtx.Request().Header.SetCookie("session_id", sessionID)
	active, err := store.Get(postCtx)
	require.NoError(t, err)
	require.False(t, active.Fresh(), "Session should not be fresh - it exists")
	require.Equal(t, sessionID, active.ID(), "Session ID should remain the same")
	require.Equal(t, "token-123", active.Get("csrf_token"))
	// Even if CSRF validation now fails, the session keeps its ID.
	require.Equal(t, sessionID, active.ID())
	active.Release()
	app.ReleaseCtx(postCtx)

	// Let the session expire.
	time.Sleep(idleTimeout + 200*time.Millisecond)

	// POST after expiration (the reported case): the old data is gone, so a
	// fresh session with a new ID is expected.
	expiredCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
	expiredCtx.Request().Header.SetCookie("session_id", sessionID)
	renewed, err := store.Get(expiredCtx)
	require.NoError(t, err)
	require.True(t, renewed.Fresh(), "Session should be fresh - old data expired")
	require.NotEqual(t, sessionID, renewed.ID(), "Should have generated new session ID (expected behavior)")
	require.Nil(t, renewed.Get("csrf_token"), "Old session data should be gone")
	renewed.Release()
	app.ReleaseCtx(expiredCtx)
}
// go test -run Test_Session_Multiple_GetSession_Calls
//
// Repeated getSession() calls within one request must keep reporting an
// existing session as not fresh. This holds even though, after the first call,
// the session ID is read from the context locals instead of the cookie.
func Test_Session_Multiple_GetSession_Calls(t *testing.T) {
	t.Parallel()
	store := NewStore()
	app := fiber.New()

	// Seed request: create and persist a session holding one value.
	seedCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
	seed, err := store.Get(seedCtx)
	require.NoError(t, err)
	require.True(t, seed.Fresh())
	sessionID := seed.ID()
	seed.Set("test_key", "test_value")
	require.NoError(t, seed.Save())
	seed.Release()
	app.ReleaseCtx(seedCtx)

	// Follow-up request carrying the session cookie.
	reqCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
	reqCtx.Request().Header.SetCookie("session_id", sessionID)

	// Call 1 loads the session from storage.
	first, err := store.getSession(reqCtx)
	require.NoError(t, err)
	require.False(t, first.Fresh(), "First call: existing session should not be fresh")
	require.Equal(t, sessionID, first.ID())
	require.Equal(t, "test_value", first.Get("test_key"))

	// Call 2 finds the ID in the context locals; this is where the old bug showed.
	second, err := store.getSession(reqCtx)
	require.NoError(t, err)
	require.False(t, second.Fresh(), "Second call: session should STILL not be fresh (bug fix verification)")
	require.Equal(t, sessionID, second.ID())
	require.Equal(t, "test_value", second.Get("test_key"))

	// Call 3 confirms the result stays consistent.
	third, err := store.getSession(reqCtx)
	require.NoError(t, err)
	require.False(t, third.Fresh(), "Third call: session should remain not fresh")
	require.Equal(t, sessionID, third.ID())

	first.Release()
	second.Release()
	third.Release()
	app.ReleaseCtx(reqCtx)
}
================================================
FILE: middleware/session/store.go
================================================
package session
import (
"context"
"encoding/gob"
"errors"
"fmt"
"time"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/internal/storage/memory"
"github.com/gofiber/fiber/v3/log"
)
// Errors returned by Store operations.
var (
	// ErrEmptySessionID is returned when an operation that needs a session ID
	// (Delete, GetByID) is given an empty string.
	ErrEmptySessionID = errors.New("session ID cannot be empty")
	// ErrSessionAlreadyLoadedByMiddleware is returned by Store.Get when the
	// session middleware has already been attached to the request context.
	ErrSessionAlreadyLoadedByMiddleware = errors.New("session already loaded by middleware")
	// ErrSessionIDNotFoundInStore is returned by Store.GetByID when the storage
	// has no data for the ID, or the stored session is past its absolute timeout.
	ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store")
)
// sessionIDKey is an unexported type for context-locals keys, so the key
// below cannot collide with keys defined by other packages.
type sessionIDKey int

// sessionIDContextKey is the context-locals key under which getSession
// caches a session ID it generated for the current request.
const sessionIDContextKey sessionIDKey = 0
// Store loads, creates and deletes sessions through the storage backend held
// in its embedded Config.
type Store struct {
	// Config supplies the storage backend, extractor, key generator and timeouts.
	Config
}
// NewStore builds a session Store from an optional configuration.
//
// Parameters:
//   - config: Variadic parameter to override default config.
//
// Returns:
//   - *Store: The session store.
//
// When the resulting config has no Storage, an in-memory storage is used.
// When AbsoluteTimeout is positive, absExpirationKey and time.Time are
// registered with encoding/gob so that absolute-expiration data can be
// encoded into and decoded from session data.
//
// Usage:
//
//	store := session.NewStore()
func NewStore(config ...Config) *Store {
	cfg := configDefault(config...)
	if cfg.Storage == nil {
		cfg.Storage = memory.New()
	}

	s := &Store{Config: cfg}
	if cfg.AbsoluteTimeout <= 0 {
		return s
	}
	s.RegisterType(absExpirationKey)
	s.RegisterType(time.Time{})
	return s
}
// RegisterType makes a custom type known to encoding/gob, so values of that
// type can be stored in session data and decoded again by any storage provider.
// It is a thin wrapper around gob.Register.
//
// Parameters:
//   - i: A value of the custom type to register.
//
// Usage:
//
//	store.RegisterType(MyCustomType{})
func (*Store) RegisterType(i any) {
	gob.Register(i)
}
// Get loads the session for the request, creating one if needed.
//
// If the session middleware has already been attached to this request, the
// middleware owns the session. In that case Get refuses and returns
// ErrSessionAlreadyLoadedByMiddleware.
//
// Parameters:
//   - c: The Fiber context.
//
// Returns:
//   - *Session: The session object.
//   - error: An error if the session retrieval fails or if the session is already loaded by the middleware.
//
// Usage:
//
//	sess, err := store.Get(c)
//	if err != nil {
//		// handle error
//	}
func (s *Store) Get(c fiber.Ctx) (*Session, error) {
	if _, loadedByMiddleware := c.Locals(middlewareContextKey).(*Middleware); loadedByMiddleware {
		return nil, ErrSessionAlreadyLoadedByMiddleware
	}
	return s.getSession(c)
}
// getSession loads the session for the request, or creates a new one.
//
// Lookup of the session ID:
//   - If an ID is stored in the context locals under sessionIDContextKey,
//     that ID is used. getSession itself stores IDs it generates there.
//   - Otherwise the configured Extractor is used.
//
// If an ID is found and the storage backend has data for it, the data is
// decoded and the session is not fresh. Otherwise a new ID is generated with
// KeyGenerator, cached in the context locals, and the session is marked fresh.
//
// With AbsoluteTimeout configured, a fresh session gets an absolute
// expiration. A session whose absolute expiration has passed is Reset and
// given a new one.
//
// The session object comes from a pool. On every error path after it has
// been acquired, it is released back to the pool.
//
// Parameters:
//   - c: The Fiber context.
//
// Returns:
//   - *Session: The session object.
//   - error: A storage, decode, or reset error.
//
// Usage:
//
//	sess, err := store.getSession(c)
//	if err != nil {
//		// handle error
//	}
func (s *Store) getSession(c fiber.Ctx) (*Session, error) {
	var rawData []byte
	var err error

	id, ok := c.Locals(sessionIDContextKey).(string)
	if !ok {
		id = s.getSessionID(c)
	}

	fresh := false // Session is not fresh initially; only set to true if we generate a new ID

	// Attempt to fetch session data if an ID is provided
	if id != "" {
		rawData, err = s.Storage.GetWithContext(c, id)
		if err != nil {
			return nil, err
		}
		if rawData == nil {
			// Data not found, prepare to generate a new session
			id = ""
		}
	}

	// Generate a new ID if needed
	if id == "" {
		fresh = true // The session is fresh if a new ID is generated
		id = s.KeyGenerator()
		c.Locals(sessionIDContextKey, id)
	}

	// Create session object
	sess := acquireSession()
	sess.mu.Lock()
	sess.ctx = c
	sess.config = s
	sess.id = id
	sess.fresh = fresh

	// Decode session data if found
	if rawData != nil {
		sess.data.Lock()
		decodeErr := sess.decodeSessionData(rawData)
		sess.data.Unlock()
		if decodeErr != nil {
			sess.mu.Unlock()
			sess.Release()
			return nil, fmt.Errorf("failed to decode session data: %w", decodeErr)
		}
	}
	sess.mu.Unlock()

	if fresh && s.AbsoluteTimeout > 0 {
		sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout))
	} else if sess.isAbsExpired() {
		if resetErr := sess.Reset(); resetErr != nil {
			// Hand the pooled session back, as the decode-error path does;
			// the caller never receives it, so nobody else would release it.
			sess.Release()
			return nil, fmt.Errorf("failed to reset session: %w", resetErr)
		}
		sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout))
	}

	return sess, nil
}
// getSessionID reads the session ID from the request with the configured
// Extractor, which comes from the shared extractors package.
//
// Parameters:
//   - c: The Fiber context.
//
// Returns:
//   - string: The extracted session ID, or "" when extraction fails.
//
// Usage:
//
//	id := store.getSessionID(c)
func (s *Store) getSessionID(c fiber.Ctx) string {
	id, extractErr := s.Extractor.Extract(c)
	if extractErr == nil {
		return id
	}
	// A failed extraction is not an error for the caller. The empty ID makes
	// getSession generate a brand-new session instead.
	return ""
}
// Reset deletes all sessions from the storage backend.
//
// Parameters:
//   - ctx: Context forwarded to the storage backend's ResetWithContext.
//
// Returns:
//   - error: An error if the reset operation fails.
//
// Usage:
//
//	err := store.Reset(ctx)
//	if err != nil {
//		// handle error
//	}
func (s *Store) Reset(ctx context.Context) error {
	return s.Storage.ResetWithContext(ctx)
}
// Delete removes the session stored under id from the storage backend.
//
// Parameters:
//   - ctx: Context forwarded to the storage backend's DeleteWithContext.
//   - id: The unique identifier of the session.
//
// Returns:
//   - error: ErrEmptySessionID if id is empty; otherwise the storage backend's error, if any.
//
// Usage:
//
//	err := store.Delete(ctx, id)
//	if err != nil {
//		// handle error
//	}
func (s *Store) Delete(ctx context.Context, id string) error {
	if id == "" {
		return ErrEmptySessionID
	}
	return s.Storage.DeleteWithContext(ctx, id)
}
// GetByID retrieves a session by its ID from the storage.
// If the session is not found, it returns nil and an error.
//
// Unlike session middleware methods, this function does not automatically:
//
//   - Load the session into the request context.
//
//   - Save the session data to the storage or update the client cookie.
//
// Important Notes:
//
//   - The session object returned by GetByID does not have a context associated with it.
//
//   - When using this method alongside session middleware, there is a potential for collisions,
//     so be mindful of interactions between manually retrieved sessions and middleware-managed sessions.
//
//   - If you modify a session returned by GetByID, you must call session.Save() to persist the changes.
//
//   - When you are done with the session, you should call session.Release() to release the session back to the pool.
//
//   - If AbsoluteTimeout is configured and the stored session has expired, it is destroyed
//     and ErrSessionIDNotFoundInStore is returned.
//
// Parameters:
//   - ctx: Context forwarded to the storage backend's GetWithContext.
//   - id: The unique identifier of the session.
//
// Returns:
//   - *Session: The session object if found; otherwise, nil.
//   - error: An error if the session retrieval fails or if the session ID is empty.
//
// Usage:
//
//	sess, err := store.GetByID(ctx, id)
//	if err != nil {
//		// handle error
//	}
func (s *Store) GetByID(ctx context.Context, id string) (*Session, error) {
	if id == "" {
		return nil, ErrEmptySessionID
	}

	rawData, err := s.Storage.GetWithContext(ctx, id)
	if err != nil {
		return nil, err
	}
	if rawData == nil {
		return nil, ErrSessionIDNotFoundInStore
	}

	sess := acquireSession()

	sess.mu.Lock()
	sess.config = s
	sess.id = id
	sess.fresh = false

	sess.data.Lock()
	decodeErr := sess.decodeSessionData(rawData)
	sess.data.Unlock()
	sess.mu.Unlock()
	if decodeErr != nil {
		sess.Release()
		return nil, fmt.Errorf("failed to decode session data: %w", decodeErr)
	}

	if s.AbsoluteTimeout > 0 && sess.isAbsExpired() {
		if err := sess.Destroy(); err != nil { //nolint:contextcheck // it is not right
			log.Errorf("failed to destroy session: %v", err)
		}
		// The caller never receives this session, so it must go back to the
		// pool whether or not Destroy succeeded. Previously it was released
		// only on failure, which leaked it on the success path.
		sess.Release()
		return nil, ErrSessionIDNotFoundInStore
	}

	return sess, nil
}
================================================
FILE: middleware/session/store_test.go
================================================
package session
import (
"context"
"fmt"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/extractors"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// go test -run Test_Store_getSessionID
func Test_Store_getSessionID(t *testing.T) {
t.Parallel()
expectedID := "test-session-id"
// fiber instance
app := fiber.New()
t.Run("from cookie", func(t *testing.T) {
t.Parallel()
// session store
store := NewStore()
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set cookie
ctx.Request().Header.SetCookie(store.Extractor.Key, expectedID)
require.Equal(t, expectedID, store.getSessionID(ctx))
})
t.Run("from header", func(t *testing.T) {
t.Parallel()
// session store
store := NewStore(Config{
Extractor: extractors.FromHeader("session_id"),
})
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set header
ctx.Request().Header.Set(store.Extractor.Key, expectedID)
require.Equal(t, expectedID, store.getSessionID(ctx))
})
t.Run("from url query", func(t *testing.T) {
t.Parallel()
// session store
store := NewStore(Config{
Extractor: extractors.FromQuery("session_id"),
})
// fiber context
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
// set url parameter
ctx.Request().SetRequestURI(fmt.Sprintf("/path?%s=%s", store.Extractor.Key, expectedID))
require.Equal(t, expectedID, store.getSessionID(ctx))
})
}
// go test -run Test_Store_Get
// Regression: https://github.com/gofiber/fiber/issues/1408
// Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v
//
// An unknown session ID supplied by the client must never be adopted; a new
// ID has to be generated instead.
func Test_Store_Get(t *testing.T) {
	t.Parallel()
	unexpectedID := "test-session-id"
	// fiber instance
	app := fiber.New()

	t.Run("session should be re-generated if it is invalid", func(t *testing.T) {
		t.Parallel()
		// session store
		store := NewStore()
		// fiber context
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)

		// set cookie
		ctx.Request().Header.SetCookie(store.Extractor.Key, unexpectedID)

		acquiredSession, err := store.Get(ctx)
		require.NoError(t, err)
		// Return the pooled session before the context is released (defers run LIFO).
		defer acquiredSession.Release()

		require.NotEqual(t, unexpectedID, acquiredSession.ID())
	})
}
// go test -run Test_Store_DeleteSession
// After a session is deleted from the store, the next Get on the same
// context must produce a session with a different ID.
func Test_Store_DeleteSession(t *testing.T) {
	t.Parallel()
	// fiber instance
	app := fiber.New()
	// session store
	store := NewStore()
	// fiber context
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)

	// Create a new session and remember its ID
	session, err := store.Get(ctx)
	require.NoError(t, err)
	sessionID := session.ID()
	// Only the ID is needed from here on; return the session to the pool
	session.Release()

	// Delete the session
	err = store.Delete(ctx, sessionID)
	require.NoError(t, err)

	// Try to get the session again
	session, err = store.Get(ctx)
	require.NoError(t, err)
	defer session.Release()

	// The session ID should be different now, because the old session was deleted
	require.NotEqual(t, sessionID, session.ID())
}
// TestStore_Get_SessionAlreadyLoaded checks that Store.Get refuses to load a
// session when the session middleware is already attached to the context.
func TestStore_Get_SessionAlreadyLoaded(t *testing.T) {
	t.Parallel()
	// Create a new Fiber app
	app := fiber.New()
	// Create a new context
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)

	// Mock middleware and set it in the context
	ctx.Locals(middlewareContextKey, &Middleware{})

	// Create a new store
	store := &Store{}

	// Call the Get method
	sess, err := store.Get(ctx)

	// Assert that the error is ErrSessionAlreadyLoadedByMiddleware
	require.Nil(t, sess)
	require.ErrorIs(t, err, ErrSessionAlreadyLoadedByMiddleware)
}
// TestStore_Delete covers Store.Delete's argument validation and its
// behaviour for IDs that do not exist in the storage.
func TestStore_Delete(t *testing.T) {
	t.Parallel()
	// Create a new store
	store := NewStore()

	t.Run("delete with empty session ID", func(t *testing.T) {
		t.Parallel()
		err := store.Delete(context.Background(), "")
		require.ErrorIs(t, err, ErrEmptySessionID)
	})

	t.Run("delete non-existing session", func(t *testing.T) {
		t.Parallel()
		err := store.Delete(context.Background(), "non-existing-session-id")
		require.NoError(t, err)
	})
}
// Test_Store_GetByID covers GetByID for empty, unknown and persisted IDs, and
// exercises the Session API on a session obtained this way (no context).
func Test_Store_GetByID(t *testing.T) {
	t.Parallel()
	// Create a new store
	store := NewStore()

	t.Run("empty session ID", func(t *testing.T) {
		t.Parallel()
		sess, err := store.GetByID(context.Background(), "")
		require.ErrorIs(t, err, ErrEmptySessionID)
		require.Nil(t, sess)
	})

	t.Run("nonexistent session ID", func(t *testing.T) {
		t.Parallel()
		sess, err := store.GetByID(context.Background(), "nonexistent-session-id")
		require.ErrorIs(t, err, ErrSessionIDNotFoundInStore)
		require.Nil(t, sess)
	})

	t.Run("valid session ID", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		// Create a new session
		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(ctx)

		session, err := store.Get(ctx)
		// Check the error before deferring Release: on failure session is nil.
		require.NoError(t, err)
		// Registered after ReleaseCtx, so it runs first (defers are LIFO).
		defer session.Release()

		// Save the session ID
		sessionID := session.ID()

		// Save the session
		err = session.Save()
		require.NoError(t, err)

		// Retrieve the session by ID
		retrievedSession, err := store.GetByID(context.Background(), sessionID)
		require.NoError(t, err)
		require.NotNil(t, retrievedSession)
		require.Equal(t, sessionID, retrievedSession.ID())

		// Call Save on the retrieved session
		retrievedSession.Set("key", "value")
		err = retrievedSession.Save()
		require.NoError(t, err)

		// Call Other Session methods
		require.Equal(t, "value", retrievedSession.Get("key"))
		require.False(t, retrievedSession.Fresh())

		require.NoError(t, retrievedSession.Reset())
		require.NoError(t, retrievedSession.Destroy())
		require.IsType(t, []any{}, retrievedSession.Keys())
		require.NoError(t, retrievedSession.Regenerate())
		require.NotPanics(t, func() {
			retrievedSession.Release()
		})
	})
}
================================================
FILE: middleware/skip/skip.go
================================================
package skip
import (
"github.com/gofiber/fiber/v3"
)
// New wraps handler so that it runs only for requests the exclude predicate
// rejects. When exclude returns true, the wrapped handler is bypassed and the
// next handler in the chain runs instead. A nil predicate means nothing is
// ever excluded, so handler itself is returned unchanged.
func New(handler fiber.Handler, exclude func(c fiber.Ctx) bool) fiber.Handler {
	if exclude == nil {
		return handler
	}

	wrapped := func(c fiber.Ctx) error {
		if !exclude(c) {
			return handler(c)
		}
		return c.Next()
	}
	return wrapped
}
================================================
FILE: middleware/skip/skip_test.go
================================================
package skip_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/middleware/skip"
"github.com/stretchr/testify/require"
)
// go test -run Test_Skip
// A predicate that always returns true bypasses the wrapped handler.
func Test_Skip(t *testing.T) {
	t.Parallel()
	alwaysSkip := func(fiber.Ctx) bool { return true }

	app := fiber.New()
	app.Use(skip.New(errTeapotHandler, alwaysSkip))
	app.Get("/", helloWorldHandler)

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}
// go test -run Test_SkipFalse
// A predicate that always returns false lets the wrapped handler run.
func Test_SkipFalse(t *testing.T) {
	t.Parallel()
	neverSkip := func(fiber.Ctx) bool { return false }

	app := fiber.New()
	app.Use(skip.New(errTeapotHandler, neverSkip))
	app.Get("/", helloWorldHandler)

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
}
// go test -run Test_SkipNilFunc
// A nil predicate leaves the wrapped handler active for every request.
func Test_SkipNilFunc(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(skip.New(errTeapotHandler, nil))
	app.Get("/", helloWorldHandler)

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTeapot, resp.StatusCode)
}
// helloWorldHandler answers with a plain greeting (status 200).
func helloWorldHandler(c fiber.Ctx) error {
	const greeting = "Hello, World 👋!"
	return c.SendString(greeting)
}

// errTeapotHandler always fails with fiber.ErrTeapot, so it is easy to tell
// from the response whether it ran.
func errTeapotHandler(fiber.Ctx) error {
	return fiber.ErrTeapot
}
================================================
FILE: middleware/static/config.go
================================================
package static
import (
"io/fs"
"time"
"github.com/gofiber/fiber/v3"
)
// Config holds the options of the static middleware.
type Config struct {
	// FS is the file system to serve files from. Any fs.FS implementation
	// works, for example embed.FS or os.DirFS.
	//
	// Optional. Default: nil
	FS fs.FS

	// Next, when it returns true, makes the middleware pass the request on
	// without serving anything.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// ModifyResponse runs after a file was served successfully and may alter
	// the response.
	//
	// Optional. Default: nil
	ModifyResponse fiber.Handler

	// NotFoundHandler handles requests whose path could not be served.
	//
	// Optional. Default: nil
	NotFoundHandler fiber.Handler

	// IndexNames lists the file names tried when a directory is requested.
	//
	// Optional. Default: []string{"index.html"}.
	IndexNames []string `json:"index"`

	// CacheDuration is how long inactive file handlers stay cached.
	// A negative value disables the cache.
	//
	// Optional. Default: 10 * time.Second.
	CacheDuration time.Duration `json:"cache_duration"`

	// MaxAge, in seconds, is written into the Cache-Control header of
	// served files. Values <= 0 leave the header unset.
	//
	// Optional. Default: 0.
	MaxAge int `json:"max_age"`

	// Compress makes the server try to minimize CPU usage by caching
	// compressed files.
	// This works differently than the github.com/gofiber/compression middleware.
	//
	// Optional. Default: false
	Compress bool `json:"compress"`

	// ByteRange enables byte range requests.
	//
	// Optional. Default: false
	ByteRange bool `json:"byte_range"`

	// Browse enables directory listings.
	//
	// Optional. Default: false.
	Browse bool `json:"browse"`

	// Download serves files as attachments (direct download).
	//
	// Optional. Default: false.
	Download bool `json:"download"`
}
// ConfigDefault is the configuration used when New is called without a
// Config. Its IndexNames and CacheDuration also fill in those fields when a
// supplied Config leaves them at their zero values.
var ConfigDefault = Config{
	CacheDuration: 10 * time.Second,
	IndexNames:    []string{"index.html"},
}
// configDefault resolves the effective configuration.
//
// Without arguments it returns ConfigDefault. Otherwise the first Config is
// used, with an empty IndexNames and a zero CacheDuration replaced by the
// defaults. A negative CacheDuration is kept, because it disables the cache.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}

	cfg := config[0]
	if len(cfg.IndexNames) == 0 {
		cfg.IndexNames = ConfigDefault.IndexNames
	}
	if cfg.CacheDuration == 0 {
		cfg.CacheDuration = ConfigDefault.CacheDuration
	}
	return cfg
}
================================================
FILE: middleware/static/static.go
================================================
package static
import (
"bytes"
"errors"
"fmt"
"io/fs"
"net/url"
"os"
pathpkg "path"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v3"
)
// ErrInvalidPath is returned by sanitizePath when a requested path is
// malformed or attempts to escape the served root (e.g. via "..", encoded
// separators, NUL bytes or absolute/volume prefixes).
var ErrInvalidPath = errors.New("invalid path")
// sanitizePath validates and cleans the requested path.
// It returns an error if the path attempts to traverse directories.
//
// Processing order (each step matters for the checks that follow):
//  1. Backslashes are converted to '/' on a copy, so the caller's p is not mutated.
//  2. The path is percent-decoded repeatedly until it stops changing, so
//     multiply-encoded sequences (e.g. "%252e%252e") cannot slip through.
//     A decoding error yields ErrInvalidPath.
//  3. Backslashes produced by decoding (e.g. "%5C") and NUL bytes are rejected.
//  4. The path is rooted at "/" and cleaned with path.Clean, then any
//     remaining ".." segment is rejected.
//  5. With no filesystem (OS mode), leading "//", volume names and drive-letter
//     prefixes are rejected. With a filesystem, the unrooted path must
//     satisfy fs.ValidPath.
//  6. A trailing slash present in the original input is restored.
//
// Parameters:
//   - p: Raw request path bytes.
//   - filesystem: The fs.FS being served, or nil when serving from the OS.
//
// Returns:
//   - []byte: The cleaned, "/"-rooted path. It is produced with
//     utils.UnsafeBytes over a string and must not be modified.
//   - error: ErrInvalidPath when the path is rejected.
func sanitizePath(p []byte, filesystem fs.FS) ([]byte, error) {
	var s string

	// Remember the trailing slash of the raw input so it can be restored after Clean.
	hasTrailingSlash := len(p) > 0 && p[len(p)-1] == '/'

	// Normalize Windows-style separators on a private copy of p.
	if bytes.IndexByte(p, '\\') >= 0 {
		b := make([]byte, len(p))
		copy(b, p)
		for i := range b {
			if b[i] == '\\' {
				b[i] = '/'
			}
		}
		s = utils.UnsafeString(b)
	} else {
		s = utils.UnsafeString(p)
	}

	// repeatedly unescape until it no longer changes, catching errors
	for strings.IndexByte(s, '%') >= 0 {
		us, err := url.PathUnescape(s)
		if err != nil {
			return nil, ErrInvalidPath
		}
		if us == s {
			break
		}
		s = us
	}

	// A backslash here can only have come from decoding (e.g. "%5C").
	if strings.IndexByte(s, '\\') >= 0 {
		return nil, ErrInvalidPath
	}

	// reject any null bytes
	if strings.IndexByte(s, '\x00') >= 0 {
		return nil, ErrInvalidPath
	}

	normalized := filepath.ToSlash(s)
	// NOTE(review): "//" prefixes are presumably rejected in OS mode because
	// they can denote UNC/network paths on Windows.
	if filesystem == nil && strings.HasPrefix(normalized, "//") {
		return nil, ErrInvalidPath
	}

	// Rooting at "/" before Clean means ".." cannot climb above the root.
	s = pathpkg.Clean("/" + normalized)

	trimmed := utils.TrimLeft(s, '/')
	// Defensive check: reject any ".." segment that survived cleaning.
	if trimmed != "" {
		if slices.Contains(strings.Split(trimmed, "/"), "..") {
			return nil, ErrInvalidPath
		}
	}

	// OS mode: reject anything that could resolve to an absolute or
	// volume-qualified location.
	if filesystem == nil {
		normalizedClean := filepath.ToSlash(trimmed)
		if strings.HasPrefix(normalizedClean, "//") {
			return nil, ErrInvalidPath
		}
		if volume := filepath.VolumeName(normalizedClean); volume != "" {
			return nil, ErrInvalidPath
		}
		// Drive-letter prefixes such as "C:" are rejected on every OS, not
		// only where filepath.VolumeName would report them.
		if len(normalizedClean) >= 2 && normalizedClean[1] == ':' {
			drive := normalizedClean[0]
			if (drive >= 'a' && drive <= 'z') || (drive >= 'A' && drive <= 'Z') {
				return nil, ErrInvalidPath
			}
		}
		if strings.HasPrefix(filepath.ToSlash(s), "//") {
			return nil, ErrInvalidPath
		}
	}

	// fs.FS mode: paths must be unrooted and satisfy fs.ValidPath; the
	// result is re-rooted with "/" for the caller.
	if filesystem != nil {
		s = trimmed
		if s == "" {
			return []byte("/"), nil
		}
		if !fs.ValidPath(s) {
			return nil, ErrInvalidPath
		}
		s = "/" + s
	}

	// Restore the trailing slash removed by Clean.
	if hasTrailingSlash && len(s) > 1 && s[len(s)-1] != '/' {
		s += "/"
	}

	return utils.UnsafeBytes(s), nil
}
// New creates a new static file middleware handler.
//
// root is the directory, or single file, to serve assets from. When
// Config.FS is set, root is resolved inside that file system and defaults
// to "." when empty.
//
// The underlying fasthttp.FS is built lazily on the first GET/HEAD request.
// That request's route path (up to any '*') becomes the prefix stripped from
// every request path. Whether root is a single file is also decided once at
// that point.
//
// Every request path is passed through sanitizePath before it reaches the
// file server. Invalid paths are mapped to a non-existent file and answered
// with 404.
func New(root string, cfg ...Config) fiber.Handler {
	config := configDefault(cfg...)

	var createFS sync.Once
	var fileHandler fasthttp.RequestHandler
	var cacheControlValue string
	var rootIsFile bool

	// adjustments for io/fs compatibility
	if config.FS != nil && root == "" {
		root = "."
	}

	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if config.Next != nil && config.Next(c) {
			return c.Next()
		}

		// We only serve static assets on GET or HEAD methods
		method := c.Method()
		if method != fiber.MethodGet && method != fiber.MethodHead {
			return c.Next()
		}

		// Initialize FS
		createFS.Do(func() {
			prefix := c.Route().Path

			// Inspect root exactly once. If it cannot be inspected it is treated
			// as a directory; request paths are still sanitized below.
			if check, err := isFile(root, config.FS); err == nil {
				rootIsFile = check
			}

			// Is prefix a partial wildcard?
			if before, _, found := strings.Cut(prefix, "*"); found {
				// /john* -> /john
				prefix = before
			}

			prefixLen := len(prefix)
			if prefixLen > 1 && prefix[prefixLen-1:] == "/" {
				// /john/ -> /john
				prefixLen--
			}

			fileServer := &fasthttp.FS{
				Root:                   root,
				FS:                     config.FS,
				AllowEmptyRoot:         true,
				GenerateIndexPages:     config.Browse,
				AcceptByteRange:        config.ByteRange,
				Compress:               config.Compress,
				CompressBrotli:         config.Compress, // Brotli compression won't work without this
				CompressZstd:           config.Compress, // Zstd compression won't work without this
				CompressedFileSuffixes: c.App().Config().CompressedFileSuffixes,
				CacheDuration:          config.CacheDuration,
				SkipCache:              config.CacheDuration < 0,
				IndexNames:             config.IndexNames,
				PathNotFound: func(fctx *fasthttp.RequestCtx) {
					fctx.Response.SetStatusCode(fiber.StatusNotFound)
				},
			}

			fileServer.PathRewrite = func(fctx *fasthttp.RequestCtx) []byte {
				path := fctx.Path()

				if len(path) >= prefixLen {
					// rootIsFile was computed once above. Re-running isFile here
					// cost an open+stat+close per request. When that check failed,
					// the old code also returned the raw path without sanitizing it.
					switch {
					case rootIsFile && fileServer.FS == nil:
						// If the root is a file, we need to reset the path to "/" always.
						path = []byte("/")
					case rootIsFile && fileServer.FS != nil:
						path = utils.UnsafeBytes(root)
					default:
						path = path[prefixLen:]
						if len(path) == 0 || path[len(path)-1] != '/' {
							path = append(path, '/')
						}
					}
				}

				if len(path) > 0 && path[0] != '/' {
					path = append([]byte("/"), path...)
				}

				sanitized, err := sanitizePath(path, fileServer.FS)
				if err != nil {
					// return a guaranteed-missing path so fs responds with 404
					return []byte("/__fiber_invalid__")
				}
				return sanitized
			}

			maxAge := config.MaxAge
			if maxAge > 0 {
				cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge)
			}

			fileHandler = fileServer.NewRequestHandler()
		})

		// Serve file
		fileHandler(c.RequestCtx())

		// Sets the response Content-Disposition header to attachment if the Download option is true
		if config.Download {
			name := filepath.Base(c.Path())
			if rootIsFile {
				name = filepath.Base(root)
			}
			c.Attachment(name)
		}

		// Return request if found and not forbidden
		status := c.RequestCtx().Response.StatusCode()
		if status != fiber.StatusNotFound && status != fiber.StatusForbidden {
			if cacheControlValue != "" {
				c.RequestCtx().Response.Header.Set(fiber.HeaderCacheControl, cacheControlValue)
			}
			if config.ModifyResponse != nil {
				return config.ModifyResponse(c)
			}
			return nil
		}

		// Return custom 404 handler if provided.
		if config.NotFoundHandler != nil {
			return config.NotFoundHandler(c)
		}

		// Reset response to default
		c.RequestCtx().SetContentType("") // Issue #420
		c.RequestCtx().Response.SetStatusCode(fiber.StatusOK)
		c.RequestCtx().Response.SetBodyString("")

		// Next middleware
		return c.Next()
	}
}
// isFile checks if the root is a file.
func isFile(root string, filesystem fs.FS) (bool, error) {
var file fs.File
var err error
if filesystem != nil {
file, err = filesystem.Open(root)
if err != nil {
return false, fmt.Errorf("static: %w", err)
}
defer func() {
_ = file.Close() //nolint:errcheck // not needed
}()
} else {
file, err = os.Open(filepath.Clean(root))
if err != nil {
return false, fmt.Errorf("static: %w", err)
}
defer func() {
_ = file.Close() //nolint:errcheck // not needed
}()
}
stat, err := file.Stat()
if err != nil {
return false, fmt.Errorf("static: %w", err)
}
return stat.Mode().IsRegular(), nil
}
================================================
FILE: middleware/static/static_test.go
================================================
package static
import (
"embed"
"io"
"io/fs"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/gofiber/fiber/v3"
)
const (
	// winOS is the runtime.GOOS value checked by tests that cannot run on Windows.
	winOS = "windows"
	// testCSSDir is the CSS fixture directory shared by several tests.
	testCSSDir = "../../.github/testdata/fs/css"
)

// testConfig bounds app.Test calls to 10 seconds and treats a timeout as a failure.
var testConfig = fiber.TestConfig{
	FailOnTimeout: true,
	Timeout:       10 * time.Second,
}
// go test -run Test_Static_Index_Default
// The root route serves index.html by default; unknown paths fall through to
// Fiber's default 404 response.
func Test_Static_Index_Default(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Get("/prefix", New("../../.github/workflows"))
	app.Get("", New("../../.github/"))
	app.Get("test", New("", Config{IndexNames: []string{"index.html"}}))

	indexResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, indexResp.StatusCode, "Status code")
	require.NotEmpty(t, indexResp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, indexResp.Header.Get(fiber.HeaderContentType))

	indexBody, err := io.ReadAll(indexResp.Body)
	require.NoError(t, err)
	require.Contains(t, string(indexBody), "Hello, World!")

	missingResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/not-found", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 404, missingResp.StatusCode, "Status code")
	require.NotEmpty(t, missingResp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextPlainCharsetUTF8, missingResp.Header.Get(fiber.HeaderContentType))

	missingBody, err := io.ReadAll(missingResp.Body)
	require.NoError(t, err)
	require.Equal(t, "Not Found", string(missingBody))
}
// go test -run Test_Static_Direct
// Files are served directly by path; non-GET/HEAD methods are not served
// (405); JSON files get the JSON content type and no Cache-Control by default.
func Test_Static_Direct(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Get("/*", New("../../.github"))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Contains(t, string(body), "Hello, World!")

	resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/index.html", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 405, resp.StatusCode, "Status code")
	require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextPlainCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/testdata/testRoutes.json", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMEApplicationJSON, resp.Header.Get(fiber.HeaderContentType))
	require.Empty(t, resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")

	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Contains(t, string(body), "test_routes")
}
// go test -run Test_Static_MaxAge
// A positive MaxAge adds "Cache-Control: public, max-age=<MaxAge>".
func Test_Static_MaxAge(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Get("/*", New("../../.github", Config{
		MaxAge: 100,
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
	require.Equal(t, "public, max-age=100", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
// go test -run Test_Static_Custom_CacheControl
// ModifyResponse can set Cache-Control for HTML files only; other files keep
// no Cache-Control header.
func Test_Static_Custom_CacheControl(t *testing.T) {
	t.Parallel()

	noCacheForHTML := func(c fiber.Ctx) error {
		if strings.Contains(c.GetRespHeader("Content-Type"), "text/html") {
			c.Response().Header.Set("Cache-Control", "no-cache, no-store, must-revalidate")
		}
		return nil
	}

	app := fiber.New()
	app.Get("/*", New("../../.github", Config{ModifyResponse: noCacheForHTML}))

	htmlResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "no-cache, no-store, must-revalidate", htmlResp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")

	yamlResp, yamlErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/config.yml", http.NoBody))
	require.NoError(t, yamlErr, "app.Test(req)")
	require.Empty(t, yamlResp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
}
func Test_Static_Disable_Cache(t *testing.T) {
// Skip on Windows. It's not possible to delete a file that is in use.
if runtime.GOOS == winOS {
t.SkipNow()
}
t.Parallel()
app := fiber.New()
file, err := os.Create("../../.github/test.txt")
require.NoError(t, err)
_, err = file.WriteString("Hello, World!")
require.NoError(t, err)
require.NoError(t, file.Close())
// Remove the file even if the test fails
defer func() {
_ = os.Remove("../../.github/test.txt") //nolint:errcheck // not needed
}()
app.Get("/*", New("../../.github/", Config{
CacheDuration: -1,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test.txt", http.NoBody))
require.NoError(t, err, "app.Test(req)")
require.Empty(t, resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Contains(t, string(body), "Hello, World!")
require.NoError(t, os.Remove("../../.github/test.txt"))
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/test.txt", http.NoBody))
require.NoError(t, err, "app.Test(req)")
require.Empty(t, resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control")
body, err = io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "Not Found", string(body))
}
// Test_Static_NotFoundHandler checks that a custom NotFoundHandler writes the
// body of a missing-file response while the 404 status is kept.
func Test_Static_NotFoundHandler(t *testing.T) {
	t.Parallel()

	custom404 := func(c fiber.Ctx) error {
		return c.SendString("Custom 404")
	}

	app := fiber.New()
	app.Get("/*", New("../../.github", Config{NotFoundHandler: custom404}))

	req := httptest.NewRequest(fiber.MethodGet, "/not-found", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 404, resp.StatusCode, "Status code")
	require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextPlainCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))

	payload, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "Custom 404", string(payload))
}
// go test -run Test_Static_Download
// With Download enabled, a single-file root is sent as an attachment named
// after the file.
func Test_Static_Download(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Get("/fiber.png", New("../../.github/testdata/fs/img/fiber.png", Config{Download: true}))

	req := httptest.NewRequest(fiber.MethodGet, "/fiber.png", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, "image/png", resp.Header.Get(fiber.HeaderContentType))
	require.Equal(t, `attachment; filename="fiber.png"`, resp.Header.Get(fiber.HeaderContentDisposition))
}
// Test_Static_Download_NonASCII checks that a non-ASCII file name is emitted
// both as a plain filename and as an RFC 5987 filename* parameter.
func Test_Static_Download_NonASCII(t *testing.T) {
	// Skip on Windows. It's not possible to delete a file that is in use,
	// which would break the t.TempDir cleanup.
	if runtime.GOOS == winOS {
		t.SkipNow()
	}

	t.Parallel()
	dir := t.TempDir()
	fname := "файл.txt"
	path := filepath.Join(dir, fname)
	require.NoError(t, os.WriteFile(path, []byte("x"), 0o644)) //nolint:gosec // Not a concern

	app := fiber.New()
	app.Get("/file", New(path, Config{Download: true}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/file", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	expect := "attachment; filename=\"" + fname + "\"; filename*=UTF-8''" + url.PathEscape(fname)
	require.Equal(t, expect, resp.Header.Get(fiber.HeaderContentDisposition))
}
// go test -run Test_Static_Group
// Static handlers work inside route groups: group middleware still runs, and
// a partial-wildcard route serves the single-file root for any sub-path.
func Test_Static_Group(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	v1 := app.Group("/v1", func(c fiber.Ctx) error {
		c.Set("Test-Header", "123")
		return c.Next()
	})
	v1.Get("/v2*", New("../../.github/index.html"))

	v1Resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/v1/v2", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, v1Resp.StatusCode, "Status code")
	require.NotEmpty(t, v1Resp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, v1Resp.Header.Get(fiber.HeaderContentType))
	require.Equal(t, "123", v1Resp.Header.Get("Test-Header"))

	v2 := app.Group("/v2")
	v2.Get("/v3*", New("../../.github/index.html"))

	v2Resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/v2/v3/john/doe", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, v2Resp.StatusCode, "Status code")
	require.NotEmpty(t, v2Resp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, v2Resp.Header.Get(fiber.HeaderContentType))
}
// Test_Static_Wildcard checks that a single file registered on a bare "*"
// route is served for an arbitrary nested request path.
func Test_Static_Wildcard(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Get("*", New("../../.github/index.html"))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/yesyes/john/doe", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))

	payload, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Contains(t, string(payload), "Test file")
}
// Test_Static_Prefix_Wildcard checks that a single file registered on a
// prefix wildcard route ("/prefix*") is served for deeper paths under that
// prefix. The body of the final response is checked for the fixture text.
func Test_Static_Prefix_Wildcard(t *testing.T) {
	t.Parallel()

	const indexFile = "../../.github/index.html"

	app := fiber.New()

	steps := []struct {
		route  string
		target string
	}{
		{route: "/test*", target: "/test/john/doe"},
		{route: "/my/nameisjohn*", target: "/my/nameisjohn/no/its/not"},
	}

	var last *http.Response
	for _, step := range steps {
		app.Get(step.route, New(indexFile))

		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, step.target, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
		require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
		last = resp
	}

	body, err := io.ReadAll(last.Body)
	require.NoError(t, err)
	require.Contains(t, string(body), "Test file")
}
// Test_Static_Prefix checks static handlers mounted on prefix wildcard routes.
// Two are rooted at directories (serving their index.html) and one at a
// single JSON file. Each response must have the expected Content-Type.
func Test_Static_Prefix(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	cases := []struct {
		route       string
		root        string
		target      string
		contentType string
	}{
		{route: "/john*", root: "../../.github", target: "/john/index.html", contentType: fiber.MIMETextHTMLCharsetUTF8},
		{route: "/prefix*", root: "../../.github/testdata", target: "/prefix/index.html", contentType: fiber.MIMETextHTMLCharsetUTF8},
		{route: "/single*", root: "../../.github/testdata/testRoutes.json", target: "/single", contentType: fiber.MIMEApplicationJSON},
	}

	for _, tc := range cases {
		app.Get(tc.route, New(tc.root))

		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tc.target, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
		require.Equal(t, tc.contentType, resp.Header.Get(fiber.HeaderContentType))
	}
}
// Test_Static_Trailing_Slash checks how directory requests with and without a
// trailing slash are handled, for roots registered via Get with a wildcard
// and via Use. A directory with an index file yields 200 HTML; testCSSDir
// yields a 404 plain-text response.
func Test_Static_Trailing_Slash(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	// check GETs path and asserts the status code, a non-empty Content-Length
	// and the given Content-Type.
	check := func(path string, wantStatus int, wantType string) {
		t.Helper()
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, wantStatus, resp.StatusCode, "Status code")
		require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
		require.Equal(t, wantType, resp.Header.Get(fiber.HeaderContentType))
	}

	// Roots registered with Get and a trailing wildcard.
	app.Get("/john*", New("../../.github"))
	check("/john/", 200, fiber.MIMETextHTMLCharsetUTF8)

	app.Get("/john_without_index*", New(testCSSDir))
	check("/john_without_index/", 404, fiber.MIMETextPlainCharsetUTF8)

	// The same roots mounted with Use.
	app.Use("/john", New("../../.github"))
	check("/john/", 200, fiber.MIMETextHTMLCharsetUTF8)
	check("/john", 200, fiber.MIMETextHTMLCharsetUTF8)

	app.Use("/john_without_index/", New(testCSSDir))
	check("/john_without_index/", 404, fiber.MIMETextPlainCharsetUTF8)
}
// Test_Static_Next checks Config.Next. When it returns true (header
// X-Custom-Header is "skip"), the static handler is bypassed and the next
// "/*" handler answers. Otherwise index.html is served.
func Test_Static_Next(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Get("/*", New("../../.github", Config{
		Next: func(c fiber.Ctx) bool {
			return c.Get("X-Custom-Header") == "skip"
		},
	}))
	// Fallback handler that the "skip" case expects to reach.
	app.Get("/*", func(c fiber.Ctx) error {
		return c.SendString("You've skipped app.Static")
	})

	cases := []struct {
		name        string
		header      string
		contentType string
		wantBody    string
	}{
		{
			name:        "app.Static is skipped: invoking Get handler",
			header:      "skip",
			contentType: fiber.MIMETextPlainCharsetUTF8,
			wantBody:    "You've skipped app.Static",
		},
		{
			name:        "app.Static is not skipped: serving index.html",
			header:      "don't skip",
			contentType: fiber.MIMETextHTMLCharsetUTF8,
			wantBody:    "Hello, World!",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
			req.Header.Set("X-Custom-Header", tc.header)

			resp, err := app.Test(req)
			require.NoError(t, err)
			require.Equal(t, 200, resp.StatusCode)
			require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength))
			require.Equal(t, tc.contentType, resp.Header.Get(fiber.HeaderContentType))

			body, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Contains(t, string(body), tc.wantBody)
		})
	}
}
// Test_Route_Static_Root checks a directory served at "/*". With Browse, "/"
// lists the directory (200); without it, "/" is a 404. In both
// configurations "/style.css" is served.
func Test_Route_Static_Root(t *testing.T) {
	t.Parallel()

	dir := testCSSDir

	// expectRootStatus GETs "/" and asserts the status code.
	expectRootStatus := func(app *fiber.App, want int) {
		t.Helper()
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, want, resp.StatusCode, "Status code")
	}
	// expectCSS GETs "/style.css" and asserts a 200 whose body mentions "color".
	expectCSS := func(app *fiber.App) {
		t.Helper()
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/style.css", http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err, "app.Test(req)")
		require.Contains(t, string(body), "color")
	}

	browsing := fiber.New()
	browsing.Get("/*", New(dir, Config{
		Browse: true,
	}))
	expectRootStatus(browsing, 200)
	expectCSS(browsing)

	plain := fiber.New()
	plain.Get("/*", New(dir))
	expectRootStatus(plain, 404)
	expectCSS(plain)
}
// Test_Route_Static_HasPrefix checks a directory served under a "/static"
// prefix. The route is registered both as "/static*" and "/static/*", each
// with and without Config.Browse.
//
// With Browse, "/static" and "/static/" list the directory (200); without it
// they are 404. "/static/style.css" is served (200, body mentions "color")
// in every configuration.
func Test_Route_Static_HasPrefix(t *testing.T) {
	t.Parallel()

	dir := testCSSDir

	// Cover every route form × Browse combination, so the non-browse
	// "/static/*" registration is exercised as well as the browse one.
	cases := []struct {
		route     string
		browse    bool
		dirStatus int
	}{
		{route: "/static*", browse: true, dirStatus: 200},
		{route: "/static/*", browse: true, dirStatus: 200},
		{route: "/static*", browse: false, dirStatus: 404},
		{route: "/static/*", browse: false, dirStatus: 404},
	}

	for _, tc := range cases {
		// Without Browse, call New with no Config so the handler is built
		// from the default configuration.
		var opts []Config
		if tc.browse {
			opts = append(opts, Config{Browse: true})
		}

		app := fiber.New()
		app.Get(tc.route, New(dir, opts...))

		for _, dirPath := range []string{"/static", "/static/"} {
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, dirPath, http.NoBody))
			require.NoError(t, err, "app.Test(req)")
			require.Equalf(t, tc.dirStatus, resp.StatusCode,
				"Status code for %s (route %s, browse %t)", dirPath, tc.route, tc.browse)
		}

		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/static/style.css", http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equalf(t, 200, resp.StatusCode,
			"Status code for /static/style.css (route %s, browse %t)", tc.route, tc.browse)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err, "app.Test(req)")
		require.Contains(t, string(body), "color")
	}
}
// Test_Static_FS checks serving from an fs.FS (os.DirFS) with Browse enabled.
// "/" must be HTML, and "/css/style.css" must be CSS whose body mentions
// "color".
func Test_Static_FS(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Get("/*", New("", Config{
		FS:     os.DirFS("../../.github/testdata/fs"),
		Browse: true,
	}))

	// fetch GETs path and asserts a 200 response.
	fetch := func(path string) *http.Response {
		t.Helper()
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		return resp
	}

	root := fetch("/")
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, root.Header.Get(fiber.HeaderContentType))

	css := fetch("/css/style.css")
	require.Equal(t, fiber.MIMETextCSSCharsetUTF8, css.Header.Get(fiber.HeaderContentType))
	cssBody, err := io.ReadAll(css.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Contains(t, string(cssBody), "color")
}
/*func Test_Static_FS_DifferentRoot(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/*", New("fs", Config{
FS: os.DirFS("../../.github/testdata"),
IndexNames: []string{"index2.html"},
Browse: true,
}))
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, 200, resp.StatusCode, "Status code")
require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
body, err := io.ReadAll(resp.Body)
require.NoError(t, err, "app.Test(req)")
require.Contains(t, string(body), "Hello, World!")
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/css/style.css", nil))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, 200, resp.StatusCode, "Status code")
require.Equal(t, fiber.MIMETextCSSCharsetUTF8, resp.Header.Get(fiber.HeaderContentType))
body, err = io.ReadAll(resp.Body)
require.NoError(t, err, "app.Test(req)")
require.Contains(t, string(body), "color")
}*/
// fsTestFilesystem holds this package's static.go and config.go, embedded at
// build time. Test_Static_FS_Browse serves it through Config.FS with Browse
// enabled.
//
//go:embed static.go config.go
var fsTestFilesystem embed.FS
// Test_Static_FS_Browse checks directory browsing over two fs.FS
// implementations: an embed.FS mounted at "/embed*" and an os.DirFS mounted
// at "/dirfs*". It covers directory listings, files, and a nested directory.
func Test_Static_FS_Browse(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Get("/embed*", New("", Config{
		FS:     fsTestFilesystem,
		Browse: true,
	}))
	app.Get("/dirfs*", New("", Config{
		FS:     os.DirFS(testCSSDir),
		Browse: true,
	}))

	// Every request must return 200 with the given Content-Type. When
	// wantBody is non-empty, the body must also contain it.
	checks := []struct {
		path        string
		contentType string
		wantBody    string
	}{
		{path: "/dirfs", contentType: fiber.MIMETextHTMLCharsetUTF8, wantBody: "style.css"},
		{path: "/dirfs/style.css", contentType: fiber.MIMETextCSSCharsetUTF8, wantBody: "color"},
		{path: "/dirfs/test", contentType: fiber.MIMETextHTMLCharsetUTF8},
		{path: "/dirfs/test/style2.css", contentType: fiber.MIMETextCSSCharsetUTF8, wantBody: "color"},
		{path: "/embed", contentType: fiber.MIMETextHTMLCharsetUTF8, wantBody: "static.go"},
	}

	for _, chk := range checks {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, chk.path, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		require.Equal(t, chk.contentType, resp.Header.Get(fiber.HeaderContentType))

		if chk.wantBody == "" {
			continue
		}
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err, "app.Test(req)")
		require.Contains(t, string(body), chk.wantBody)
	}
}
// Test_Static_FS_Prefix_Wildcard checks that a single file taken from an
// fs.FS and registered on a prefix wildcard route is served for deeper paths
// under that prefix.
func Test_Static_FS_Prefix_Wildcard(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Get("/test*", New("index.html", Config{
		FS:         os.DirFS("../../.github"),
		IndexNames: []string{"not_index.html"},
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test/john/doe", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	headers := resp.Header
	require.NotEmpty(t, headers.Get(fiber.HeaderContentLength))
	require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, headers.Get(fiber.HeaderContentType))

	content, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Contains(t, string(content), "Test file")
}
// Test_isFile table-tests isFile. Paths are resolved through an fs.FS or,
// when filesystem is nil, directly on the OS. Existing files report true.
// Directories report false with a nil error. Missing paths report false with
// an error matching fs.ErrNotExist.
func Test_isFile(t *testing.T) {
	t.Parallel()

	cases := []struct {
		filesystem fs.FS
		gotError   error
		name       string
		path       string
		expected   bool
	}{
		{
			name:       "fs: existing file",
			path:       "index.html",
			filesystem: os.DirFS("../../.github"),
			expected:   true,
		},
		{
			name:       "fs: missing file",
			path:       "index2.html",
			filesystem: os.DirFS("../../.github"),
			expected:   false,
			gotError:   fs.ErrNotExist,
		},
		{
			name:       "fs: root directory",
			path:       ".",
			filesystem: os.DirFS("../../.github"),
			expected:   false,
		},
		{
			name:       "fs: missing directory",
			path:       "not_exists",
			filesystem: os.DirFS("../../.github"),
			expected:   false,
			gotError:   fs.ErrNotExist,
		},
		{
			name:       "fs: css root directory",
			path:       ".",
			filesystem: os.DirFS(testCSSDir),
			expected:   false,
		},
		{
			name:       "os: existing file",
			path:       testCSSDir + "/style.css",
			filesystem: nil,
			expected:   true,
		},
		{
			name:       "os: missing file",
			path:       testCSSDir + "/style2.css",
			filesystem: nil,
			expected:   false,
			gotError:   fs.ErrNotExist,
		},
		{
			name:       "os: directory",
			path:       testCSSDir,
			filesystem: nil,
			expected:   false,
		},
	}

	for _, tc := range cases {
		// Since Go 1.22 each iteration has its own tc, so the parallel subtest
		// below captures the right case without a manual copy.
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			actual, err := isFile(tc.path, tc.filesystem)
			// With a nil gotError this asserts err == nil.
			require.ErrorIs(t, err, tc.gotError)
			require.Equal(t, tc.expected, actual)
		})
	}
}
// Test_Static_Compress checks Config.Compress for each supported encoding
// (zstd, gzip, br).
//
// A small file (css/style.css, 46 bytes) must be served without a
// Content-Encoding and with its original length. index.html must be served
// with the negotiated encoding and a body shorter than its uncompressed size.
func Test_Static_Compress(t *testing.T) {
	t.Parallel()

	// uncompressedIndexLen is the uncompressed size of index.html in dir, in bytes.
	const uncompressedIndexLen int64 = 299

	dir := "../../.github/testdata/fs"
	app := fiber.New()
	app.Get("/*", New(dir, Config{
		Compress: true,
	}))

	// Note: deflate is not supported by fasthttp.FS
	algorithms := []string{"zstd", "gzip", "br"}
	for _, algo := range algorithms {
		t.Run(algo+"_compression", func(t *testing.T) {
			t.Parallel()

			// request non-compressible file (less than 200 bytes), Content Length will remain the same
			req := httptest.NewRequest(fiber.MethodGet, "/css/style.css", http.NoBody)
			req.Header.Set("Accept-Encoding", algo)
			resp, err := app.Test(req, testConfig)
			require.NoError(t, err, "app.Test(req)")
			require.Equal(t, 200, resp.StatusCode, "Status code")
			require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
			require.Equal(t, "46", resp.Header.Get(fiber.HeaderContentLength))

			// request compressible file, ContentLength will change
			req = httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody)
			req.Header.Set("Accept-Encoding", algo)
			resp, err = app.Test(req, testConfig)
			require.NoError(t, err, "app.Test(req)")
			require.Equal(t, 200, resp.StatusCode, "Status code")
			require.Equal(t, algo, resp.Header.Get(fiber.HeaderContentEncoding))
			// Compare numerically: comparing the header string against "299"
			// is lexicographic. It would accept "1000" and a missing header,
			// and reject "50".
			require.Positive(t, resp.ContentLength)
			require.Less(t, resp.ContentLength, uncompressedIndexLen)
		})
	}
}
// Test_Static_Compress_WithoutEncoding checks Config.Compress across encodings.
//
// A request without Accept-Encoding returns index.html uncompressed. Each
// subsequent request with an encoding, made after the 1s cache has expired,
// returns a smaller compressed body. It also leaves a compressed sibling file
// on disk using the default ".fiber.*" suffixes.
func Test_Static_Compress_WithoutEncoding(t *testing.T) {
	t.Parallel()

	// uncompressedIndexLen is the uncompressed size of index.html in dir, in bytes.
	const uncompressedIndexLen int64 = 299

	dir := "../../.github/testdata/fs"
	app := fiber.New()
	app.Get("/*", New(dir, Config{
		Compress:      true,
		CacheDuration: 1 * time.Second,
	}))

	// request compressible file without encoding
	req := httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody)
	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "299", resp.Header.Get(fiber.HeaderContentLength))

	// request compressible file with different encodings
	algorithms := []string{"zstd", "gzip", "br"}
	fileSuffixes := map[string]string{
		"gzip": ".fiber.gz",
		"br":   ".fiber.br",
		"zstd": ".fiber.zst",
	}

	for _, algo := range algorithms {
		// Wait for cache to expire
		time.Sleep(2 * time.Second)

		fileName := "index.html"
		compressedFileName := filepath.Join(dir, fileName+fileSuffixes[algo])

		req = httptest.NewRequest(fiber.MethodGet, "/"+fileName, http.NoBody)
		req.Header.Set("Accept-Encoding", algo)
		resp, err = app.Test(req, testConfig)
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		require.Equal(t, algo, resp.Header.Get(fiber.HeaderContentEncoding))
		// Numeric comparison; a string comparison against "299" is lexicographic.
		require.Positive(t, resp.ContentLength)
		require.Less(t, resp.ContentLength, uncompressedIndexLen)

		// verify suffixed file was created (assign, don't shadow, err)
		_, err = os.Stat(compressedFileName)
		require.NoError(t, err, "File should exist")
	}
}
// Test_Static_Compress_WithFileSuffixes checks that fiber.Config's
// CompressedFileSuffixes controls the names of the compressed sibling files
// written next to the original. For each encoding, the compressed response
// must be smaller than the uncompressed file, and a file with the custom
// ".test.*" suffix must exist afterwards.
func Test_Static_Compress_WithFileSuffixes(t *testing.T) {
	t.Parallel()

	// uncompressedIndexLen is the uncompressed size of index.html in dir, in bytes.
	const uncompressedIndexLen int64 = 299

	dir := "../../.github/testdata/fs"
	fileSuffixes := map[string]string{
		"gzip": ".test.gz",
		"br":   ".test.br",
		"zstd": ".test.zst",
	}

	app := fiber.New(fiber.Config{
		CompressedFileSuffixes: fileSuffixes,
	})
	app.Get("/*", New(dir, Config{
		Compress:      true,
		CacheDuration: 1 * time.Second,
	}))

	// request compressible file with different encodings
	algorithms := []string{"zstd", "gzip", "br"}
	for _, algo := range algorithms {
		// Wait for cache to expire
		time.Sleep(2 * time.Second)

		fileName := "index.html"
		compressedFileName := filepath.Join(dir, fileName+fileSuffixes[algo])

		req := httptest.NewRequest(fiber.MethodGet, "/"+fileName, http.NoBody)
		req.Header.Set("Accept-Encoding", algo)
		resp, err := app.Test(req, testConfig)
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		require.Equal(t, algo, resp.Header.Get(fiber.HeaderContentEncoding))
		// Numeric comparison; a string comparison against "299" is lexicographic.
		require.Positive(t, resp.ContentLength)
		require.Less(t, resp.ContentLength, uncompressedIndexLen)

		// verify suffixed file was created
		_, err = os.Stat(compressedFileName)
		require.NoError(t, err, "File should exist")
	}
}
// Test_Router_Mount_n_Static checks that a static handler mounted with Use
// keeps serving its files when the app also has a root route, a mounted
// sub-app, and a trailing catch-all 404 middleware.
func Test_Router_Mount_n_Static(t *testing.T) {
	t.Parallel()

	home := func(c fiber.Ctx) error {
		return c.SendString("Home")
	}
	subTest := func(c fiber.Ctx) error {
		return c.SendString("Hello from /test")
	}
	notFound := func(c fiber.Ctx) error {
		return c.Status(fiber.StatusNotFound).SendString("Not Found")
	}

	app := fiber.New()
	app.Use("/static", New(testCSSDir, Config{Browse: true}))
	app.Get("/", home)

	sub := fiber.New()
	app.Use("/mount", sub)
	sub.Get("/test", subTest)

	// Registered last so it only catches unmatched requests.
	app.Use(notFound)

	req := httptest.NewRequest(fiber.MethodGet, "/static/style.css", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
}
// Test_Static_PathTraversal checks, on non-Windows platforms, that a static
// handler rooted at testCSSDir serves files inside that directory but blocks
// attempts to escape it. The attempts cover plain and percent-encoded "..",
// double slashes, double encoding, null bytes, backslashes and absolute
// paths.
func Test_Static_PathTraversal(t *testing.T) {
	// Windows traversal cases live in Test_Static_PathTraversal_WindowsOnly.
	if runtime.GOOS == winOS {
		t.Skip("Skipping non-Windows path traversal tests on Windows")
	}
	t.Parallel()
	app := fiber.New()

	// Serve only from testCSSDir
	// This directory should contain `style.css` but not `index.html` or anything above it.
	rootDir := testCSSDir
	app.Get("/*", New(rootDir))

	// A valid request: should succeed
	validReq := httptest.NewRequest(fiber.MethodGet, "/style.css", http.NoBody)
	validResp, err := app.Test(validReq)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, validResp.StatusCode, "Status code")
	require.Equal(t, fiber.MIMETextCSSCharsetUTF8, validResp.Header.Get(fiber.HeaderContentType))
	validBody, err := io.ReadAll(validResp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Contains(t, string(validBody), "color")

	// assertTraversalBlocked asserts that path is blocked. Both 400 and 404
	// count as blocked:
	// - 404 is the expected blocked response in most cases.
	// - 400 might occur if fasthttp rejects the request before it's even processed (e.g., null bytes).
	assertTraversalBlocked := func(path string) {
		t.Helper()
		req := httptest.NewRequest(fiber.MethodGet, path, http.NoBody)
		resp, err := app.Test(req)
		require.NoError(t, err, "app.Test(req)")

		status := resp.StatusCode
		require.Truef(t, status == 400 || status == 404,
			"Status code for path traversal %s should be 400 or 404, got %d", path, status)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)

		// A 404 should carry fiber's default "Not Found" message. A 400
		// should carry the "Are you a hacker?" message.
		if status == 404 {
			require.Contains(t, string(body), "Not Found",
				"Blocked traversal should have a \"Not Found\" message for %s", path)
		} else {
			require.Contains(t, string(body), "Are you a hacker?",
				"Blocked traversal should have an \"Are you a hacker?\" message for %s", path)
		}
	}

	// Basic attempts to escape the directory
	assertTraversalBlocked("/index.html..")
	assertTraversalBlocked("/style.css..")
	assertTraversalBlocked("/../index.html")
	assertTraversalBlocked("/../../index.html")
	assertTraversalBlocked("/../../../index.html")

	// Attempts with double slashes
	assertTraversalBlocked("//../index.html")
	assertTraversalBlocked("/..//index.html")

	// Encoded attempts: `%2e` is '.' and `%2f` is '/'
	assertTraversalBlocked("/..%2findex.html")         // ../index.html
	assertTraversalBlocked("/%2e%2e/index.html")       // ../index.html
	assertTraversalBlocked("/%2e%2e%2f%2e%2e/secret") // ../../secret

	// Mixed encoded and normal attempts
	assertTraversalBlocked("/%2e%2e/../index.html")  // ../../index.html
	assertTraversalBlocked("/..%2f..%2fsecret.json") // ../../secret.json

	// Attempts with current directory references
	assertTraversalBlocked("/./../index.html")
	assertTraversalBlocked("/././../index.html")

	// Trailing slashes
	assertTraversalBlocked("/../")
	assertTraversalBlocked("/../../")

	// Attempts to load files from an absolute path outside the root
	assertTraversalBlocked("/" + rootDir + "/../../index.html")

	// Additional edge cases:

	// Double-encoded `..`
	assertTraversalBlocked("/%252e%252e/index.html") // double-encoded .. -> ../index.html after double decoding

	// Multiple levels of encoding and traversal
	assertTraversalBlocked("/%2e%2e%2F..%2f%2e%2e%2fWINDOWS")       // multiple ups and unusual pattern
	assertTraversalBlocked("/%2e%2e%2F..%2f%2e%2e%2f%2e%2e/secret") // more complex chain of ../

	// Null byte attempts
	assertTraversalBlocked("/index.html%00.jpg")
	assertTraversalBlocked("/%00index.html")
	assertTraversalBlocked("/somefolder%00/something")
	assertTraversalBlocked("/%00/index.html")

	// Attempts to access known system files
	assertTraversalBlocked("/etc/passwd")
	assertTraversalBlocked("/etc/")

	// Complex mixed attempts with encoded slashes and dots
	assertTraversalBlocked("/..%2F..%2F..%2F..%2Fetc%2Fpasswd")

	// Attempts inside subdirectories with encoded traversal
	assertTraversalBlocked("/somefolder/%2e%2e%2findex.html")
	assertTraversalBlocked("/somefolder/%2e%2e%2f%2e%2e%2findex.html")

	// Backslash encoded attempts
	assertTraversalBlocked("/%5C..%5Cindex.html")
	assertTraversalBlocked("/%5c..%5c..%5cetc%5cpasswd")
	assertTraversalBlocked("/%255c..%255c..%255cetc%255cpasswd")
	assertTraversalBlocked("/..%5c..%5cetc%5cpasswd")
	assertTraversalBlocked("/%2e%2e%5c%2e%2e%5cetc%5cpasswd")
	assertTraversalBlocked("/%2e%2e%2f%2e%2e%5cetc%5cpasswd")
	assertTraversalBlocked("/%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd")
	assertTraversalBlocked("/.%2e/.%2e/etc/passwd")
	assertTraversalBlocked("/..%2f..%2f..%2f..%2fetc%2fpasswd")
	assertTraversalBlocked("/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd")
	assertTraversalBlocked("/%2e%2e%2f%2e%2e%2fetc%2fshadow")
	assertTraversalBlocked("/%2e%2e/%2e%2e/var/log/auth.log")
	assertTraversalBlocked("/..%2f..%2fvar%2flog%2fauth.log")
	assertTraversalBlocked("/%2e%2e//%2e%2e//etc/passwd")
	assertTraversalBlocked("/..//..//etc/passwd")
	assertTraversalBlocked("/%2e%2e%2f%2e%2e%2f%2e%2e%2fproc%2fself%2fenviron")
	assertTraversalBlocked("/%2e%2e%2f%2e%2e%2f%2e%2e%2froot%2f.ssh%2fauthorized_keys")
}
// Test_Static_PathTraversal_WindowsOnly checks, on Windows only, that a
// static handler rooted at testCSSDir serves files inside that directory but
// blocks Windows-style escape attempts. These include backslash separators,
// drive letters, UNC-like paths and encoded variants of them.
func Test_Static_PathTraversal_WindowsOnly(t *testing.T) {
	// Skip this test if not running on Windows
	if runtime.GOOS != winOS {
		t.Skip("Skipping Windows-specific tests")
	}
	t.Parallel()
	app := fiber.New()

	// Serve only from testCSSDir
	rootDir := testCSSDir
	app.Get("/*", New(rootDir))

	// A valid request (relative path without backslash):
	validReq := httptest.NewRequest(fiber.MethodGet, "/style.css", http.NoBody)
	validResp, err := app.Test(validReq)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, validResp.StatusCode, "Status code for valid file on Windows")
	validBody, err := io.ReadAll(validResp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Contains(t, string(validBody), "color")

	// assertTraversalBlocked asserts that path is blocked with a 400 or 404,
	// and that the blocked response's own body carries the matching message.
	assertTraversalBlocked := func(path string) {
		t.Helper()
		req := httptest.NewRequest(fiber.MethodGet, path, http.NoBody)
		resp, err := app.Test(req)
		require.NoError(t, err, "app.Test(req)")

		// We expect a blocked request to return either 400 or 404
		status := resp.StatusCode
		require.Containsf(t, []int{400, 404}, status,
			"Status code for path traversal %s should be 400 or 404, got %d", path, status)

		// Read this response's body before branching. Asserting on the
		// valid request's body here would make the 400 branch always fail.
		respBody, err := io.ReadAll(resp.Body)
		require.NoError(t, err)

		if status == 404 {
			require.Contains(t, string(respBody), "Not Found",
				"Blocked traversal should have a \"Not Found\" message for %s", path)
		} else {
			require.Contains(t, string(respBody), "Are you a hacker?",
				"Blocked traversal should have an \"Are you a hacker?\" message for %s", path)
		}
	}

	// Windows-specific traversal attempts
	// Backslashes are treated as directory separators on Windows.
	assertTraversalBlocked("/..\\index.html")
	assertTraversalBlocked("/..\\..\\index.html")
	assertTraversalBlocked("/..\\..\\..\\Windows\\win.ini")
	assertTraversalBlocked("/..\\..\\..\\Windows\\System32\\drivers\\etc\\hosts")
	assertTraversalBlocked("/%5C..%5C..%5CWindows%5Cwin.ini")
	assertTraversalBlocked("/%255C..%255C..%255CWindows%255Cwin.ini")
	assertTraversalBlocked("/%5c..%5c..%5cWindows%5cSystem32%5cdrivers%5cetc%5chosts")
	assertTraversalBlocked("/C:\\Windows\\System32\\cmd.exe")
	assertTraversalBlocked("/C:%5CWindows%5CSystem32%5Ccmd.exe")
	assertTraversalBlocked("/%43:%5CWindows%5CSystem32%5Ccmd.exe")
	assertTraversalBlocked("/%5c%5cserver%5cshare%5csecret.txt")
	assertTraversalBlocked("//server\\share\\secret.txt")
	assertTraversalBlocked("//server/share/secret.txt")
	assertTraversalBlocked("/%2F%2Fserver%2Fshare%2Fsecret.txt")

	// Attempt with a path that might try to reference Windows drives or absolute paths
	// Note: These are artificial tests to ensure no drive-letter escapes are allowed.
	assertTraversalBlocked("/C:\\Windows\\System32\\cmd.exe")
	assertTraversalBlocked("/C:/Windows/System32/cmd.exe")

	// Attempt with UNC-like paths (though unlikely in a web context, good to test)
	assertTraversalBlocked("//server\\share\\secret.txt")

	// Attempt using a mixture of forward and backward slashes
	assertTraversalBlocked("/..\\..\\/index.html")

	// Attempt that includes a null-byte on Windows
	assertTraversalBlocked("/index.html%00.txt")

	// Check behavior on an obviously nonexistent and suspicious file
	assertTraversalBlocked("/\\this\\path\\does\\not\\exist\\..")

	// Attempts involving relative traversal and current directory reference
	assertTraversalBlocked("/.\\../index.html")
	assertTraversalBlocked("/./..\\index.html")
}
// Benchmark_SanitizePath measures sanitizePath, with allocation reporting, on
// percent-encoded and backslash-separated inputs. Both a nil filesystem and
// an os.DirFS are used. A sanitize error aborts the benchmark.
func Benchmark_SanitizePath(b *testing.B) {
	cases := []struct {
		filesystem fs.FS
		name       string
		path       []byte
	}{
		{name: "nilFS - urlencoded chars", filesystem: nil, path: []byte("/foo%2Fbar/../baz%20qux/index.html")},
		{name: "dirFS - urlencoded chars", filesystem: os.DirFS("."), path: []byte("/foo%2Fbar/../baz%20qux/index.html")},
		{name: "nilFS - slashes", filesystem: nil, path: []byte("\\foo%2Fbar\\baz%20qux\\index.html")},
	}

	for _, bc := range cases {
		b.Run(bc.name, func(b *testing.B) {
			b.ReportAllocs()
			for b.Loop() {
				if _, err := sanitizePath(bc.path, bc.filesystem); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}
// Test_SanitizePath table-tests sanitizePath on well-formed inputs. Each
// input must sanitize without error to the expected slash-separated path.
// The cases cover traversal, percent-encoding, backslashes and trailing
// slashes, with and without an fs.FS.
func Test_SanitizePath(t *testing.T) {
	t.Parallel()

	type sanitizeCase struct {
		filesystem fs.FS
		name       string
		expectPath string
		input      []byte
	}

	cases := []sanitizeCase{
		{name: "simple path", input: []byte("/foo/bar.txt"), expectPath: "/foo/bar.txt"},
		{name: "traversal attempt", input: []byte("/foo/../../bar.txt"), expectPath: "/bar.txt"},
		{name: "encoded traversal", input: []byte("/foo/%2e%2e/bar.txt"), expectPath: "/bar.txt"},
		{name: "double encoded traversal", input: []byte("/%252e%252e/bar.txt"), expectPath: "/bar.txt"},
		{name: "current dir reference", input: []byte("/foo/./bar.txt"), expectPath: "/foo/bar.txt"},
		{name: "encoded slash", input: []byte("/foo%2Fbar.txt"), expectPath: "/foo/bar.txt"},
		{name: "empty path", input: []byte(""), expectPath: "/"},
		{name: "dot segments", input: []byte("/foo/./bar/../baz.txt"), expectPath: "/foo/baz.txt"},
		{name: "leading dot segment", input: []byte("/./foo/bar.txt"), expectPath: "/foo/bar.txt"},
		{name: "encoded space", input: []byte("/foo%20bar/baz.txt"), expectPath: "/foo bar/baz.txt"},
		{name: "encoded plus literal", input: []byte("/foo+bar/baz.txt"), expectPath: "/foo+bar/baz.txt"},
		// windows-specific paths
		{name: "backslash path", input: []byte("\\foo\\bar.txt"), expectPath: "/foo/bar.txt"},
		{name: "backslash traversal", input: []byte("\\foo\\..\\..\\bar.txt"), expectPath: "/bar.txt"},
		{name: "mixed slashes", input: []byte("/foo\\bar.txt"), expectPath: "/foo/bar.txt"},
		{name: "trailing slash preserved", input: []byte("/foo/bar/"), expectPath: "/foo/bar/"},
		{name: "encoded trailing slash", input: []byte("/foo/bar%2F"), expectPath: "/foo/bar"},
		{filesystem: os.DirFS("."), name: "filesystem empty path", input: []byte(""), expectPath: "/"},
		{filesystem: os.DirFS("."), name: "filesystem trailing slash", input: []byte("/foo/"), expectPath: "/foo/"},
		{filesystem: os.DirFS("."), name: "filesystem traversal clean", input: []byte("/foo/../bar.txt"), expectPath: "/bar.txt"},
	}

	for i := range cases {
		sc := cases[i]
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()
			cleaned, err := sanitizePath(sc.input, sc.filesystem)
			require.NoError(t, err)
			require.Equal(t, sc.expectPath, string(cleaned))
		})
	}
}
// Test_SanitizePath_Error checks that sanitizePath rejects null bytes, encoded
// backslashes, drive letters, UNC paths and double slashes with ErrInvalidPath.
func Test_SanitizePath_Error(t *testing.T) {
	t.Parallel()
	type errCase struct {
		filesystem fs.FS
		name       string
		input      []byte
	}
	// bad builds a case that runs without a filesystem.
	bad := func(name, in string) errCase {
		return errCase{name: name, input: []byte(in)}
	}
	cases := []errCase{
		bad("null byte", "/foo/bar.txt%00"),
		bad("encoded backslash traversal", "/foo%5C..%5Cbar.txt"),
		bad("double encoded backslash traversal", "/%255C..%255C..%255CWindows%255Cwin.ini"),
		bad("encoded backslash absolute", "/%5CWindows%5CSystem32%5Cdrivers%5Cetc%5Chosts"),
		bad("double encoded backslash absolute", "/%255CWindows%255CSystem32%255Cdrivers%255Cetc%255Chosts"),
		bad("encoded backslash mixed slashes", "/..%5C..%5Cetc%5Cpasswd"),
		bad("encoded backslash mixed encoding", "/%2e%2e%5c%2e%2e%5cetc%5cpasswd"),
		bad("encoded backslash with encoded slash", "/%2e%2e%2f%2e%2e%5cetc%5cpasswd"),
		bad("encoded backslash unc path", "//server%5Cshare%5Csecret.txt"),
		bad("encoded backslash drive letter", "/C:%5CWindows%5CSystem32%5Ccmd.exe"),
		bad("double slash path", "//foo//bar.txt"),
		bad("drive letter", "C:/Windows/System32/cmd.exe"),
		bad("drive letter with leading slash", "/C:/Windows/System32/cmd.exe"),
		bad("encoded drive letter", "/%43:%5CWindows%5CSystem32%5Ccmd.exe"),
		bad("unc path", "//server/share/secret.txt"),
		bad("encoded unc path", "/%2F%2Fserver%2Fshare%2Fsecret.txt"),
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()
			_, err := sanitizePath(c.input, c.filesystem)
			require.ErrorIs(t, err, ErrInvalidPath, "Expected ErrInvalidPath for input: %s", c.input)
		})
	}
}
================================================
FILE: middleware/timeout/config.go
================================================
package timeout
import (
"time"
"github.com/gofiber/fiber/v3"
)
// Config holds the configuration for the timeout middleware.
type Config struct {
	// Next defines a function to skip this middleware.
	// When it returns true, the wrapped handler is called directly, without a deadline.
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool
	// OnTimeout is executed when a timeout occurs. Its return value becomes the
	// middleware's return value.
	// Optional. Default: nil (return fiber.ErrRequestTimeout)
	OnTimeout fiber.Handler
	// Errors defines custom errors that are treated as timeouts.
	// They are matched with errors.Is, so wrapped errors are recognized too.
	// Optional. Default: nil
	Errors []error
	// Timeout defines the timeout duration for all routes.
	// Zero disables the timeout; negative values are treated like zero.
	// Optional. Default: 0 (no timeout)
	Timeout time.Duration
}
// ConfigDefault is the default configuration: no skip function, no OnTimeout
// handler, no custom timeout errors and a zero Timeout (which disables the
// deadline).
var ConfigDefault = Config{
	Next:      nil,
	OnTimeout: nil,
	Errors:    nil,
	Timeout:   0,
}
// configDefault merges the optional user Config with ConfigDefault.
// With no argument it returns ConfigDefault as-is. Otherwise only config[0]
// is used: unset (nil) fields and a negative Timeout fall back to the values
// currently held by ConfigDefault.
func configDefault(config ...Config) Config {
	if len(config) == 0 {
		return ConfigDefault
	}
	cfg := config[0]
	def := ConfigDefault
	if cfg.Next == nil {
		cfg.Next = def.Next
	}
	if cfg.OnTimeout == nil {
		cfg.OnTimeout = def.OnTimeout
	}
	if cfg.Errors == nil {
		cfg.Errors = def.Errors
	}
	if cfg.Timeout < 0 {
		cfg.Timeout = def.Timeout
	}
	return cfg
}
================================================
FILE: middleware/timeout/timeout.go
================================================
package timeout
import (
"context"
"errors"
"runtime/debug"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
)
// New enforces a timeout for each incoming request. It replaces the request's
// context with one that has the configured deadline, which is exposed through
// c.Context(). Handlers can detect the timeout by listening on c.Context().Done()
// and return early.
//
// When cfg.Next returns true, or the configured Timeout is not positive, h is
// called directly on the current goroutine without any deadline.
//
// Otherwise h runs on a separate goroutine (sharing the same fiber.Ctx) and the
// middleware waits for the first of three outcomes:
//   - h returns: the timeout context is canceled, the parent context restored,
//     and the error is mapped by handleResult (timeout-like errors go through
//     OnTimeout or become fiber.ErrRequestTimeout).
//   - h panics: the panic is recovered and logged with a stack trace, and
//     fiber.ErrInternalServerError is returned.
//   - the timeout context is done: the middleware returns immediately with
//     fiber.ErrRequestTimeout (or the result of OnTimeout if configured) while h
//     keeps running. The timed-out fiber.Ctx is abandoned and is currently NOT
//     returned to the pool; only the timeout context is canceled once h finishes.
//
// Because the timeout context is derived from the request's existing context,
// cancellation of that parent context also makes it done and is handled the
// same way as a timeout.
func New(h fiber.Handler, config ...Config) fiber.Handler {
	cfg := configDefault(config...)
	return func(ctx fiber.Ctx) error {
		if cfg.Next != nil && cfg.Next(ctx) {
			return h(ctx)
		}
		timeout := cfg.Timeout
		if timeout <= 0 {
			return h(ctx)
		}
		// Create timeout context - handler can check c.Context().Done()
		// parent is kept so it can be restored once the handler is done.
		parent := ctx.Context()
		tCtx, cancel := context.WithTimeout(parent, timeout)
		ctx.SetContext(tCtx)
		// Channels for handler result and panics.
		// Both have capacity 1, so the goroutine's single send never blocks even
		// after the middleware has stopped listening.
		done := make(chan error, 1)
		panicChan := make(chan any, 1)
		// Run handler in goroutine so we can race against the timeout
		go func() {
			defer func() {
				if p := recover(); p != nil {
					log.Errorw("panic recovered in timeout handler", "panic", p, "stack", string(debug.Stack()))
					select {
					case panicChan <- p:
					default:
						// Middleware already returned, panic value discarded
					}
				}
			}()
			err := h(ctx)
			select {
			case done <- err:
			default:
				// Middleware already returned, error discarded
			}
		}()
		// Wait for handler completion, panic, or timeout.
		// If several cases are ready at once, select picks one at random.
		select {
		case err := <-done:
			// Handler finished normally - cleanup and return
			cancel()
			ctx.SetContext(parent)
			return handleResult(err, ctx, cfg)
		case <-panicChan:
			// Handler panicked - cleanup and return error
			cancel()
			ctx.SetContext(parent)
			return fiber.ErrInternalServerError
		case <-tCtx.Done():
			// Timeout occurred - abandon context and return immediately
			// The cleanup goroutine will cancel the timeout context once the handler finishes;
			// the abandoned fiber.Ctx stays out of the pool.
			return handleTimeout(parent, ctx, cancel, done, panicChan, cfg)
		}
	}
}
// handleResult maps the wrapped handler's return value to the middleware's
// result. Errors recognized by isTimeoutError are routed through
// invokeOnTimeout; anything else, including nil, is returned unchanged.
func handleResult(err error, ctx fiber.Ctx, cfg Config) error {
	if err == nil || !isTimeoutError(err, cfg.Errors) {
		return err
	}
	return invokeOnTimeout(ctx, cfg)
}
// handleTimeout handles the timeout case using the Abandon mechanism.
//
// It marks ctx as abandoned and builds the timeout response. The response comes
// from OnTimeout, or is a default 408 whose body is fiber.ErrRequestTimeout.Message.
// That response is handed to fasthttp via TimeoutErrorWithResponse. A goroutine
// is then started that waits for the still-running handler to finish (normally
// or by panic) before canceling the timeout context and restoring parent.
//
// Parameters:
//   - parent: the request context that was active before the timeout context was installed.
//   - ctx: the fiber context that is being abandoned.
//   - cancel: cancels the timeout context.
//   - done, panicChan: receive the handler goroutine's result or recovered panic.
//   - cfg: the middleware configuration (OnTimeout is used here).
//
// It returns the value produced by invokeOnTimeout.
func handleTimeout(
	parent context.Context,
	ctx fiber.Ctx,
	cancel context.CancelFunc,
	done <-chan error,
	panicChan <-chan any,
	cfg Config,
) error {
	// Mark fiber context as abandoned - ReleaseCtx will skip pooling.
	// The context will NOT be returned to the pool. This is an intentional
	// trade-off: we accept the small memory cost of not recycling timed-out
	// contexts in exchange for complete race-freedom.
	//
	// This is the same approach fasthttp uses - timed-out RequestCtx objects
	// are never returned to the pool (see fasthttp's releaseCtx which panics
	// if timeoutResponse is set).
	ctx.Abandon()
	// Prepare the timeout response before marking the RequestCtx as timed out so
	// custom OnTimeout handlers can shape the response body.
	// NOTE(review): the handler goroutine may still be running and using ctx here,
	// so OnTimeout and the response writes below can overlap with it; confirm
	// this is acceptable.
	timeoutErr := invokeOnTimeout(ctx, cfg)
	// If no OnTimeout handler is configured or the response is still the default
	// 200/empty, ensure a sensible timeout response is captured for fasthttp to send.
	if cfg.OnTimeout == nil || (ctx.Response().StatusCode() == fiber.StatusOK && len(ctx.Response().Body()) == 0) {
		ctx.Response().SetStatusCode(fiber.StatusRequestTimeout)
		if len(ctx.Response().Body()) == 0 {
			ctx.Response().SetBodyString(fiber.ErrRequestTimeout.Message)
		}
	}
	// Tell fasthttp to not recycle the RequestCtx - it will acquire a new one
	// for the response and send the captured payload (either default or from
	// OnTimeout). All ctx mutations after this call are ignored by fasthttp.
	ctx.RequestCtx().TimeoutErrorWithResponse(&ctx.RequestCtx().Response)
	// Spawn cleanup goroutine that waits for handler to finish.
	// This only does context cleanup (cancel + restore parent), NOT ctx release.
	// The fiber.Ctx is intentionally NOT released to avoid races with requestHandler
	// which may still access ctx (e.g., ErrorHandler) after this function returns.
	// ForceRelease cannot be called safely here for the same reason.
	go func() {
		// Exactly one of the two channels receives a value from the handler goroutine.
		select {
		case <-done:
		case <-panicChan:
		}
		// Handler finished - cancel timeout context and restore parent
		cancel()
		ctx.SetContext(parent)
		// TODO: Currently the ctx is not returned to the pool (memory leak for timed-out requests).
		// Future improvement: Implement a concurrent "garbage collector" list where abandoned
		// contexts are queued after both the handler AND requestHandler are done. A background
		// goroutine would periodically process this list and call ForceRelease() to recycle
		// the contexts safely. This would require tracking when requestHandler finishes
		// (e.g., via a channel signaled in ReleaseCtx) without adding per-request overhead
		// for non-timeout cases.
	}()
	return timeoutErr
}
// invokeOnTimeout runs the configured OnTimeout handler and returns its
// result. When no handler is configured it returns fiber.ErrRequestTimeout.
func invokeOnTimeout(ctx fiber.Ctx, cfg Config) error {
	handler := cfg.OnTimeout
	if handler == nil {
		return fiber.ErrRequestTimeout
	}
	return handler(ctx)
}
// isTimeoutError reports whether err is timeout-like. That means it matches
// context.DeadlineExceeded or any entry of customErrors. errors.Is is used,
// so wrapped errors are recognized; a nil err never matches.
func isTimeoutError(err error, customErrors []error) bool {
	if errors.Is(err, context.DeadlineExceeded) {
		return true
	}
	// Ranging over a nil or empty slice is a no-op, so no length guard is needed.
	for _, target := range customErrors {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
================================================
FILE: middleware/timeout/timeout_test.go
================================================
package timeout
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v3"
)
var (
	// errUnrelated is not registered as a timeout error anywhere, so handlers
	// returning it must surface a regular error response.
	errUnrelated = errors.New("unmatched error")
	// errCustomTimeout is listed in Config.Errors by tests, so returning it
	// (possibly wrapped) is handled like a timeout.
	errCustomTimeout = errors.New("custom timeout error")
)
// sleepWithContext waits for d and returns nil, unless ctx is done first, in
// which case it returns te right away. The timer is always stopped on return.
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
	wait := time.NewTimer(d)
	defer wait.Stop()
	select {
	case <-wait.C:
		return nil
	case <-ctx.Done():
		return te
	}
}
// TestTimeout_Success checks that a handler finishing well inside the
// configured timeout responds normally with 200 OK.
func TestTimeout_Success(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	handler := func(c fiber.Ctx) error {
		// About 10ms of simulated work against a 50ms budget.
		err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded)
		if err != nil {
			return err
		}
		return c.SendString("OK")
	}
	app.Get("/fast", New(handler, Config{Timeout: 50 * time.Millisecond}))
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fast", http.NoBody))
	require.NoError(t, err, "app.Test(req) should not fail")
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK for fast requests")
}
// TestTimeout_Exceeded checks that a handler running past the deadline yields
// 408 and that the response arrives soon after the deadline rather than after
// the handler's full run time.
func TestTimeout_Exceeded(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	slow := func(c fiber.Ctx) error {
		// Would take 200ms, but returns as soon as the 50ms deadline fires.
		if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil {
			return err
		}
		return c.SendString("Should never get here")
	}
	app.Get("/slow", New(slow, Config{Timeout: 50 * time.Millisecond}))
	req := httptest.NewRequest(fiber.MethodGet, "/slow", http.NoBody)
	began := time.Now()
	resp, err := app.Test(req)
	took := time.Since(began)
	require.NoError(t, err, "app.Test(req) should not fail")
	require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 Request Timeout")
	require.Less(t, took, 150*time.Millisecond, "handler should return early on context cancelation")
}
// TestTimeout_ContextPropagation checks that the handler sees the timeout
// context through c.Context() and observes context.DeadlineExceeded when the
// deadline passes (Issue #3671).
func TestTimeout_ContextPropagation(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	reported := make(chan error, 1)
	app.Get("/context-aware", New(func(c fiber.Ctx) error {
		work := time.NewTimer(500 * time.Millisecond)
		defer work.Stop()
		select {
		case <-c.Context().Done():
			ctxErr := c.Context().Err()
			reported <- ctxErr
			return ctxErr
		case <-work.C:
			reported <- nil
			return c.SendString("completed")
		}
	}, Config{Timeout: 50 * time.Millisecond}))
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/context-aware", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode)
	guard := time.NewTimer(1 * time.Second)
	defer guard.Stop()
	select {
	case <-guard.C:
		t.Fatal("timed out waiting for handler to report context state")
	case got := <-reported:
		require.ErrorIs(t, got, context.DeadlineExceeded, "handler should report DeadlineExceeded")
	}
}
// TestTimeout_HandlerReturnsEarlyOnCancel checks that a handler polling
// c.Context() between work slices stops early, so the request completes far
// sooner than the handler's nominal 500ms of work.
func TestTimeout_HandlerReturnsEarlyOnCancel(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Get("/early-return", New(func(c fiber.Ctx) error {
		// 50 slices of 10ms each, with a cancelation check between slices.
		for range 50 {
			select {
			case <-time.After(10 * time.Millisecond):
			case <-c.Context().Done():
				return c.Context().Err()
			}
		}
		return c.SendString("completed")
	}, Config{Timeout: 30 * time.Millisecond}))
	req := httptest.NewRequest(fiber.MethodGet, "/early-return", http.NoBody)
	began := time.Now()
	resp, err := app.Test(req)
	took := time.Since(began)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode)
	require.Less(t, took, 100*time.Millisecond)
}
// TestTimeout_CustomError tests that returning a user-defined error is also treated as a timeout.
//
// The handler returns the (wrapped) custom error immediately, well before the
// deadline. The 408 can therefore only come from Config.Errors matching the
// handler's result; the deadline itself never fires. Without the custom error
// registered, the response would be a regular 500.
func TestTimeout_CustomError(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Get("/custom", New(func(_ fiber.Ctx) error {
		return fmt.Errorf("wrapped: %w", errCustomTimeout)
	}, Config{Timeout: 100 * time.Millisecond, Errors: []error{errCustomTimeout}}))
	req := httptest.NewRequest(fiber.MethodGet, "/custom", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req) should not fail")
	require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 for custom timeout error")
}
// TestTimeout_UnmatchedError checks that if the handler returns an error
// that is neither a deadline exceeded nor a custom 'timeout' error, it is
// propagated as a regular 500 (internal server error).
func TestTimeout_UnmatchedError(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/unmatched", New(func(_ fiber.Ctx) error {
return errUnrelated // Not in the custom error list
}, Config{Timeout: 100 * time.Millisecond, Errors: []error{errCustomTimeout}}))
req := httptest.NewRequest(fiber.MethodGet, "/unmatched", http.NoBody)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode,
"Expected 500 because the error is not recognized as a timeout error")
}
// TestTimeout_ZeroDuration checks that calling New without a Config (so the
// default zero Timeout applies) installs no deadline. A 50ms handler therefore
// still answers 200 OK.
func TestTimeout_ZeroDuration(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	unlimited := New(func(c fiber.Ctx) error {
		time.Sleep(50 * time.Millisecond)
		return c.SendString("No timeout used")
	})
	app.Get("/zero", unlimited)
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/zero", http.NoBody))
	require.NoError(t, err, "app.Test(req) should not fail")
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout")
}
// TestTimeout_NegativeDuration ensures negative timeout values fall back to zero.
func TestTimeout_NegativeDuration(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/negative", New(func(c fiber.Ctx) error {
time.Sleep(50 * time.Millisecond)
return c.SendString("No timeout used")
}, Config{Timeout: -100 * time.Millisecond}))
req := httptest.NewRequest(fiber.MethodGet, "/negative", http.NoBody)
resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req) should not fail")
require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout")
}
// TestTimeout_CustomHandler checks that OnTimeout runs exactly once on timeout
// and that the response it writes (status and JSON body) reaches the client.
func TestTimeout_CustomHandler(t *testing.T) {
	t.Parallel()
	var calls atomic.Int32
	onTimeout := func(c fiber.Ctx) error {
		calls.Add(1)
		return c.Status(fiber.StatusRequestTimeout).JSON(fiber.Map{"error": "timeout"})
	}
	app := fiber.New()
	app.Get("/custom-handler", New(func(c fiber.Ctx) error {
		if err := sleepWithContext(c.Context(), 100*time.Millisecond, context.DeadlineExceeded); err != nil {
			return err
		}
		return c.SendString("should not reach")
	}, Config{Timeout: 20 * time.Millisecond, OnTimeout: onTimeout}))
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/custom-handler", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode)
	require.Equal(t, int32(1), calls.Load())
	body, readErr := io.ReadAll(resp.Body)
	require.NoError(t, readErr)
	require.JSONEq(t, `{"error":"timeout"}`, string(body))
}
// TestTimeout_PanicInHandler checks that a handler panicking before the
// deadline is recovered by the middleware and turned into 500.
func TestTimeout_PanicInHandler(t *testing.T) {
	t.Parallel()
	panicky := func(_ fiber.Ctx) error {
		panic("test panic")
	}
	app := fiber.New()
	app.Get("/panic", New(panicky, Config{Timeout: 100 * time.Millisecond}))
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
}
// TestIsTimeoutError_DeadlineExceeded checks that context.DeadlineExceeded is
// recognized both directly and when wrapped.
func TestIsTimeoutError_DeadlineExceeded(t *testing.T) {
	t.Parallel()
	inputs := []error{
		context.DeadlineExceeded,
		fmt.Errorf("wrap: %w", context.DeadlineExceeded),
	}
	for _, in := range inputs {
		require.True(t, isTimeoutError(in, nil))
	}
}
// TestIsTimeoutError_CustomErrors checks that registered custom errors match
// directly and when wrapped, while unrelated errors do not.
func TestIsTimeoutError_CustomErrors(t *testing.T) {
	t.Parallel()
	customErr := errors.New("custom timeout")
	registered := []error{customErr}
	require.True(t, isTimeoutError(customErr, registered))
	require.True(t, isTimeoutError(fmt.Errorf("wrap: %w", customErr), registered))
	require.False(t, isTimeoutError(errUnrelated, registered))
}
// TestIsTimeoutError_WithOnTimeout checks, through the full middleware, that a
// handler returning a registered custom error triggers OnTimeout and that
// OnTimeout's result becomes the middleware's result.
func TestIsTimeoutError_WithOnTimeout(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)
	var onTimeoutCalled bool
	wrapped := New(func(_ fiber.Ctx) error {
		return fmt.Errorf("wrap: %w", errCustomTimeout)
	}, Config{
		Timeout: 100 * time.Millisecond,
		Errors:  []error{errCustomTimeout},
		OnTimeout: func(_ fiber.Ctx) error {
			onTimeoutCalled = true
			return errors.New("handled")
		},
	})
	err := wrapped(ctx)
	require.True(t, onTimeoutCalled)
	require.EqualError(t, err, "handled")
}
// TestTimeout_ImmediateReturn verifies that the middleware returns immediately on timeout
// without waiting for the handler to finish (using Abandon mechanism).
func TestTimeout_ImmediateReturn(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	handlerStarted := make(chan struct{})
	handlerDone := make(chan struct{})
	app.Get("/immediate", New(func(_ fiber.Ctx) error {
		close(handlerStarted)
		// Handler takes 500ms but middleware should return after 20ms
		time.Sleep(500 * time.Millisecond)
		close(handlerDone)
		return nil
	}, Config{Timeout: 20 * time.Millisecond}))
	req := httptest.NewRequest(fiber.MethodGet, "/immediate", http.NoBody)
	start := time.Now()
	resp, err := app.Test(req)
	elapsed := time.Since(start)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode)
	// Middleware should return immediately after timeout, not wait 500ms
	require.Less(t, elapsed, 100*time.Millisecond, "middleware should return immediately on timeout")
	// The abandoned handler keeps running after the response was produced;
	// wait for it so it does not outlive the test unobserved.
	<-handlerStarted
	select {
	case <-handlerDone:
		// Handler finished. The middleware's cleanup goroutine now cancels the
		// timeout context; the abandoned fiber.Ctx itself is intentionally not
		// returned to the pool.
	case <-time.After(1 * time.Second):
		t.Log("Handler still running (expected for abandoned context)")
	}
}
// TestTimeout_PanicAfterTimeout ensures panics after a timeout are handled:
// the client still gets the 408 that was produced when the deadline fired, and
// the later panic is recovered inside the middleware's handler goroutine
// instead of crashing the process.
func TestTimeout_PanicAfterTimeout(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	panicDone := make(chan struct{})
	app.Get("/panic-after-timeout", New(func(c fiber.Ctx) error {
		<-c.Context().Done()
		// Closed while the panic below unwinds the handler, before the
		// middleware's deferred recover runs.
		defer close(panicDone)
		panic("panic after timeout")
	}, Config{Timeout: 20 * time.Millisecond}))
	req := httptest.NewRequest(fiber.MethodGet, "/panic-after-timeout", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	// With immediate return, we get 408 (not 500) because panic happens after middleware returned
	require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode)
	// Wait for the handler to actually panic (recovered in the handler goroutine)
	select {
	case <-panicDone:
	case <-time.After(200 * time.Millisecond):
		t.Fatal("panic did not occur")
	}
}
// TestTimeout_Next checks that when Next returns true the handler runs without
// a deadline. A handler slower than the 10ms timeout therefore still answers
// 200 OK.
func TestTimeout_Next(t *testing.T) {
	t.Parallel()
	skipAlways := func(_ fiber.Ctx) bool {
		return true
	}
	app := fiber.New()
	app.Get("/skip", New(func(c fiber.Ctx) error {
		time.Sleep(100 * time.Millisecond)
		return c.SendString("OK")
	}, Config{Next: skipAlways, Timeout: 10 * time.Millisecond}))
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/skip", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "Middleware should be skipped")
}
// TestTimeout_ContextCleanup verifies that a handler which keeps running after
// the timeout can still finish on its own, while the client has already been
// answered with 408. The timed-out fiber.Ctx is abandoned (not returned to the
// pool); only the timeout context is cleaned up once the handler returns.
func TestTimeout_ContextCleanup(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	handlerDone := make(chan struct{})
	app.Get("/cleanup", New(func(c fiber.Ctx) error {
		defer close(handlerDone)
		<-c.Context().Done()
		// Small delay to simulate work the handler does after cancelation
		time.Sleep(50 * time.Millisecond)
		return nil
	}, Config{Timeout: 20 * time.Millisecond}))
	req := httptest.NewRequest(fiber.MethodGet, "/cleanup", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode)
	// Wait for the abandoned handler to finish; afterwards the middleware's
	// cleanup goroutine cancels the timeout context and restores the parent.
	select {
	case <-handlerDone:
		// Give cleanup goroutine time to run
		time.Sleep(20 * time.Millisecond)
	case <-time.After(200 * time.Millisecond):
		t.Fatal("handler did not finish")
	}
}
// TestTimeout_AbandonMechanism checks the Abandon/IsAbandoned flag on
// fiber.Ctx and that ReleaseCtx leaves an abandoned context untouched.
func TestTimeout_AbandonMechanism(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	// ReleaseCtx skips abandoned contexts, so force-release at test end.
	t.Cleanup(c.ForceRelease)
	require.False(t, c.IsAbandoned())
	c.Abandon()
	require.True(t, c.IsAbandoned())
	// Releasing an abandoned context must be a no-op.
	app.ReleaseCtx(c)
	require.True(t, c.IsAbandoned(), "ReleaseCtx should not release abandoned context")
	// ForceRelease is deliberately not asserted on: the timeout middleware never
	// hands abandoned contexts back to the pool, mirroring how fasthttp treats
	// timed-out RequestCtx objects, to avoid races with requestHandler.
}
================================================
FILE: mount.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"sort"
"sync"
"sync/atomic"
"github.com/gofiber/utils/v2"
)
// Put fields related to mounting.
// mountFields groups the bookkeeping an App needs for mounting sub-apps.
type mountFields struct {
	// Mounted and main apps, keyed by full mount path.
	// The "" key always maps to the owning app itself (see newMountFields).
	appList map[string]*App
	// Prefix of app if it was mounted
	mountPath string
	// Ordered keys of apps (sorted by key length for Render)
	appListKeys []string
	// check added routes of sub-apps
	// Ensures processSubAppsRoutes runs at most once for this app.
	subAppsRoutesAdded sync.Once
	// check mounted sub-apps
	// Ensures appendSubAppLists + generateAppListKeys run at most once.
	subAppsProcessed sync.Once
}
// newMountFields creates the mount bookkeeping for app. The app list starts
// with app itself under the "" key, and the ordered key list starts out as an
// empty (non-nil) slice.
func newMountFields(app *App) *mountFields {
	fields := &mountFields{appListKeys: []string{}}
	fields.appList = map[string]*App{"": app}
	return fields
}
// mount attaches another app instance as a sub-router along a routing path.
// It's very useful to split up a large API as many independent routers and
// compose them as a single service using Mount. The fiber's error handler and
// any of the fiber's sub apps are added to the application's error handlers
// to be invoked on errors that happen within the prefix route.
//
// Every entry of subApp's app list (subApp itself under the "" key plus any
// apps already mounted inside it) is registered in app's list under its full
// path, and its mountPath is updated. The mount is then registered as a "use"
// route, and subApp's OnMount hooks run; a hook error causes a panic.
func (app *App) mount(prefix string, subApp *App) Router {
	prefix = utils.TrimRight(prefix, '/')
	if prefix == "" {
		prefix = "/"
	}

	app.mutex.Lock()
	// Support for configs of mounted-apps and sub-mounted-apps.
	// A distinct loop variable avoids shadowing the subApp parameter.
	for mountedPrefix, mountedApp := range subApp.mountFields.appList {
		path := getGroupPath(prefix, mountedPrefix)
		mountedApp.mountFields.mountPath = path
		app.mountFields.appList[path] = mountedApp
	}
	app.mutex.Unlock()

	// register mounted group
	mountGroup := &Group{Prefix: prefix, app: subApp}
	app.register([]string{methodUse}, prefix, mountGroup)

	// Execute onMount hooks
	if err := subApp.hooks.executeOnMountHooks(app); err != nil {
		panic(err)
	}

	return app
}
// mount attaches another app instance as a sub-router along a routing path
// relative to the group's prefix. It's very useful to split up a large API as
// many independent routers and compose them as a single service using Mount.
//
// Every entry of subApp's app list (subApp itself under the "" key plus any
// apps already mounted inside it) is registered in the group's app list under
// its full path, and its mountPath is updated. The mount is then registered as
// a "use" route, and subApp's OnMount hooks run; a hook error causes a panic.
func (grp *Group) mount(prefix string, subApp *App) Router {
	groupPath := getGroupPath(grp.Prefix, prefix)
	groupPath = utils.TrimRight(groupPath, '/')
	if groupPath == "" {
		groupPath = "/"
	}

	grp.app.mutex.Lock()
	// Support for configs of mounted-apps and sub-mounted-apps.
	// A distinct loop variable avoids shadowing the subApp parameter.
	for mountedPrefix, mountedApp := range subApp.mountFields.appList {
		path := getGroupPath(groupPath, mountedPrefix)
		mountedApp.mountFields.mountPath = path
		grp.app.mountFields.appList[path] = mountedApp
	}
	grp.app.mutex.Unlock()

	// register mounted group
	mountGroup := &Group{Prefix: groupPath, app: subApp}
	grp.app.register([]string{methodUse}, groupPath, mountGroup)

	// Execute onMount hooks
	if err := subApp.hooks.executeOnMountHooks(grp.app); err != nil {
		panic(err)
	}

	return grp
}
// MountPath returns the route pattern at which this app was mounted as a
// sub-application. It is the empty string for an app that was never mounted.
func (app *App) MountPath() (path string) {
	path = app.mountFields.mountPath
	return path
}
// hasMountedApps reports whether any sub-app is mounted on this app. The app
// list always contains the app itself under the "" key, so any further entry
// means a mount exists.
func (app *App) hasMountedApps() bool {
	const selfEntries = 1
	return len(app.mountFields.appList) > selfEntries
}
// mountStartupProcess prepares mounted sub-apps at startup. It does nothing
// when no sub-app is mounted. Otherwise it first flattens nested sub-app lists
// and builds the sorted key list, then merges sub-app routes into this app;
// each step runs at most once.
func (app *App) mountStartupProcess() {
	if !app.hasMountedApps() {
		return
	}
	// Flatten nested sub-app lists and build the key list exactly once.
	app.mountFields.subAppsProcessed.Do(func() {
		app.appendSubAppLists(app.mountFields.appList)
		app.generateAppListKeys()
	})
	// Merge the routes of the sub-apps into this app exactly once.
	app.mountFields.subAppsRoutesAdded.Do(app.processSubAppsRoutes)
}
// generateAppListKeys generates app list keys for Render, should work after appendSubAppLists.
// The keys of appList are appended to appListKeys and the whole list is sorted
// by ascending key length. Keys of equal length are ordered lexicographically,
// so the result does not depend on Go's randomized map iteration order.
func (app *App) generateAppListKeys() {
	keys := app.mountFields.appListKeys
	for key := range app.mountFields.appList {
		keys = append(keys, key)
	}
	sort.Slice(keys, func(i, j int) bool {
		if len(keys[i]) != len(keys[j]) {
			return len(keys[i]) < len(keys[j])
		}
		// Tie-break for a deterministic order among equal-length keys.
		return keys[i] < keys[j]
	})
	app.mountFields.appListKeys = keys
}
// appendSubAppLists flattens nested sub-app lists into app's appList.
// Each non-root entry of appList is registered under its full path, prefixed
// with parent[0] when given, unless that path is already present. Entries that
// have nested apps of their own are processed recursively, with their full
// path as the new parent.
func (app *App) appendSubAppLists(appList map[string]*App, parent ...string) {
	var base string
	if len(parent) > 0 {
		base = parent[0]
	}
	for key, sub := range appList {
		if key == "" {
			// The "" entry is the app that owns appList, not a mounted app.
			continue
		}
		fullPath := key
		if base != "" {
			fullPath = getGroupPath(base, key)
		}
		if _, exists := app.mountFields.appList[fullPath]; !exists {
			app.mountFields.appList[fullPath] = sub
		}
		// A sub-app's own list always contains itself; more than one entry
		// means it has nested apps that need flattening too.
		if len(sub.mountFields.appList) <= 1 {
			continue
		}
		app.appendSubAppLists(sub.mountFields.appList, fullPath)
	}
}
// processSubAppsRoutes adds routes of sub-apps recursively when the server is started.
//
// It first lets every mounted sub-app that has mounts of its own expand its
// routes (at most once, guarded by subAppsRoutesAdded). Then, for each stack
// index m, it replaces every mount route in app.stack[m] in place with
// prefixed copies of the mounted app's routes for the same index. While
// walking the stacks it recomputes the handler count and stores it in
// app.handlersCount.
func (app *App) processSubAppsRoutes() {
	for prefix, subApp := range app.mountFields.appList {
		// skip real app
		if prefix == "" {
			continue
		}
		// process the inner routes
		if subApp.hasMountedApps() {
			subApp.mountFields.subAppsRoutesAdded.Do(func() {
				subApp.processSubAppsRoutes()
			})
		}
	}
	var handlersCount uint32
	// Iterate over the stack of the parent app
	// (presumably one stack per HTTP method — NOTE(review): confirm).
	for m := range app.stack {
		// Iterate over each route in the stack
		stackLen := len(app.stack[m])
		for i := 0; i < stackLen; i++ {
			route := app.stack[m][i]
			// Check if the route has a mounted app
			if !route.mount {
				// Regular routes are always counted; "use" routes only for
				// stack index 0, so they are not counted once per stack.
				if !route.use || (route.use && m == 0) {
					handlersCount += uint32(len(route.Handlers)) //nolint:gosec // G115 - handler count is always small
				}
				continue
			}
			// Create a slice to hold the sub-app's routes
			subRoutes := make([]*Route, len(route.group.app.stack[m]))
			// Iterate over the sub-app's routes
			for j, subAppRoute := range route.group.app.stack[m] {
				// Clone the sub-app's route
				subAppRouteClone := app.copyRoute(subAppRoute)
				// Add the parent route's path as a prefix to the sub-app's route
				app.addPrefixToRoute(route.path, subAppRouteClone)
				// Add the cloned sub-app's route to the slice of sub-app routes
				subRoutes[j] = subAppRouteClone
			}
			// Insert the sub-app's routes into the parent app's stack,
			// replacing the single mount route at position i.
			newStack := make([]*Route, len(app.stack[m])+len(subRoutes)-1)
			copy(newStack[:i], app.stack[m][:i])
			copy(newStack[i:i+len(subRoutes)], subRoutes)
			copy(newStack[i+len(subRoutes):], app.stack[m][i+1:])
			app.stack[m] = newStack
			// Step back so the loop revisits position i, which now holds the
			// first spliced route (or the next original route if subRoutes was empty).
			i--
			// Mark the parent app's routes as refreshed
			app.routesRefreshed = true
			// update stackLen after appending subRoutes to app.stack[m]
			stackLen = len(app.stack[m])
		}
	}
	atomic.StoreUint32(&app.handlersCount, handlersCount)
}
================================================
FILE: mount_test.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
// go test -run Test_App_Mount
// Test_App_Mount checks that a sub-app mounted via Use serves its routes under
// the mount prefix and that handlers are counted once.
func Test_App_Mount(t *testing.T) {
	t.Parallel()
	micro := New()
	micro.Get("/doe", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})
	app := New()
	app.Use("/john", micro)
	req := httptest.NewRequest(MethodGet, "/john/doe", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.Equal(t, uint32(2), app.handlersCount)
}
// Test_App_Mount_RootPath_Nested mounts an app at "/" that itself mounts
// another app under "/api", and checks the grouped route of the innermost app.
func Test_App_Mount_RootPath_Nested(t *testing.T) {
	t.Parallel()

	app := New()
	dynamic := New()
	apiServer := New()

	v1 := apiServer.Group("/v1")
	v1.Get("/home", func(c Ctx) error {
		return c.SendString("home")
	})

	dynamic.Use("/api", apiServer)
	app.Use("/", dynamic)

	req := httptest.NewRequest(MethodGet, "/api/v1/home", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
	require.Equal(t, uint32(2), app.handlersCount)
}
// go test -run Test_App_Mount_Nested
// Test_App_Mount_Nested mounts apps three levels deep ("/one/two/three") and
// registers the routes only after the mounts, checking that routes added to an
// already mounted sub-app are still reachable under the combined prefix.
func Test_App_Mount_Nested(t *testing.T) {
	t.Parallel()
	app := New()
	one := New()
	two := New()
	three := New()
	// mount first (in a non top-down order), register the routes afterwards
	two.Use("/three", three)
	app.Use("/one", one)
	one.Use("/two", two)
	one.Get("/doe", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})
	two.Get("/nested", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})
	three.Get("/test", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})
	// one route per nesting level
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/one/doe", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/one/two/nested", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/one/two/three/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.Equal(t, uint32(6), app.handlersCount)
}
// go test -run Test_App_Mount_Express_Behavior
// Test_App_Mount_Express_Behavior checks Express-like precedence between the
// routes of a parent app and a sub-app mounted at "/". The sub-app's routes take
// the position of the mount in the parent's stack, regardless of when they were
// registered. So a parent route registered before the mount wins (/hello), while
// sub-app routes beat parent routes registered after the mount
// (/world, /bar, /foo).
func Test_App_Mount_Express_Behavior(t *testing.T) {
	t.Parallel()
	// createTestHandler returns a handler that responds with body
	createTestHandler := func(body string) func(c Ctx) error {
		return func(c Ctx) error {
			return c.SendString(body)
		}
	}
	// testEndpoint performs a GET on route and asserts status code and body
	testEndpoint := func(app *App, route, expectedBody string, expectedStatusCode int) {
		resp, err := app.Test(httptest.NewRequest(MethodGet, route, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, expectedStatusCode, resp.StatusCode, "Status code")
		require.Equal(t, expectedBody, string(body), "Unexpected response body")
	}
	app := New()
	subApp := New()
	// app setup (the registration order below is what this test is about)
	subApp.Get("/hello", createTestHandler("subapp hello!"))
	subApp.Get("/world", createTestHandler("subapp world!")) // <- wins
	app.Get("/hello", createTestHandler("app hello!"))       // <- wins
	app.Use("/", subApp)                                     // <- subApp registration
	app.Get("/world", createTestHandler("app world!"))
	app.Get("/bar", createTestHandler("app bar!"))
	subApp.Get("/bar", createTestHandler("subapp bar!")) // <- wins
	subApp.Get("/foo", createTestHandler("subapp foo!")) // <- wins
	app.Get("/foo", createTestHandler("app foo!"))
	// 404 Handler
	app.Use(func(c Ctx) error {
		return c.SendStatus(StatusNotFound)
	})
	// expectation check
	testEndpoint(app, "/world", "subapp world!", StatusOK)
	testEndpoint(app, "/hello", "app hello!", StatusOK)
	testEndpoint(app, "/bar", "subapp bar!", StatusOK)
	testEndpoint(app, "/foo", "subapp foo!", StatusOK)
	testEndpoint(app, "/unknown", ErrNotFound.Message, StatusNotFound)
	require.Equal(t, uint32(17), app.handlersCount)
}
// go test -run Test_App_Mount_RoutePositions
// Test_App_Mount_RoutePositions checks that mounted sub-app routes are inserted
// at the position of the mount in the parent's GET stack. The middleware that
// sets the local "world" to "hello" is registered before subApp2 is mounted, so
// the subApp2 handler must observe "hello"; a wrong position would yield "world".
func Test_App_Mount_RoutePositions(t *testing.T) {
	t.Parallel()
	// testEndpoint performs a GET on route and asserts a 200 with expectedBody
	testEndpoint := func(app *App, route, expectedBody string) {
		resp, err := app.Test(httptest.NewRequest(MethodGet, route, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, StatusOK, resp.StatusCode, "Status code")
		require.Equal(t, expectedBody, string(body), "Unexpected response body")
	}
	app := New()
	subApp1 := New()
	subApp2 := New()
	// app setup
	{
		app.Use(func(c Ctx) error {
			// set initial value
			c.Locals("world", "world")
			return c.Next()
		})
		// subApp1 has no routes, so its mount contributes nothing to the stack
		app.Use("/subApp1", subApp1)
		app.Use(func(c Ctx) error {
			return c.Next()
		})
		app.Get("/bar", func(c Ctx) error {
			return c.SendString("ok")
		})
		app.Use(func(c Ctx) error {
			// is overwritten when the positioning is not correct
			c.Locals("world", "hello")
			return c.Next()
		})
		methods := subApp2.Group("/subApp2")
		methods.Get("/world", func(c Ctx) error {
			v, ok := c.Locals("world").(string)
			if !ok {
				panic("unexpected data type")
			}
			return c.SendString(v)
		})
		app.Use("", subApp2)
	}
	testEndpoint(app, "/subApp2/world", "hello")
	// inspect the GET stack: 3 middlewares, /bar, then the mounted subApp2 route
	routeStackGET := app.Stack()[0]
	require.True(t, routeStackGET[0].use)
	require.Equal(t, "/", routeStackGET[0].path)
	require.True(t, routeStackGET[1].use)
	require.Equal(t, "/", routeStackGET[1].path)
	require.False(t, routeStackGET[2].use)
	require.Equal(t, "/bar", routeStackGET[2].path)
	require.True(t, routeStackGET[3].use)
	require.Equal(t, "/", routeStackGET[3].path)
	require.False(t, routeStackGET[4].use)
	// stored lowercased: the apps use the default (case-insensitive) config
	require.Equal(t, "/subapp2/world", routeStackGET[4].path)
	require.Len(t, routeStackGET, 5)
}
// go test -run Test_App_MountPath
// Test_App_MountPath checks that MountPath reports the full prefix of every
// nested sub-app and is empty for the root app.
func Test_App_MountPath(t *testing.T) {
	t.Parallel()

	app := New()
	one := New()
	two := New()
	three := New()

	two.Use("/three", three)
	one.Use("/two", two)
	app.Use("/one", one)

	expectations := []struct {
		sub  *App
		want string
	}{
		{sub: one, want: "/one"},
		{sub: two, want: "/one/two"},
		{sub: three, want: "/one/two/three"},
	}
	for _, e := range expectations {
		require.Equal(t, e.want, e.sub.MountPath())
	}
	require.Empty(t, app.MountPath())
}
func Test_App_ErrorHandler_GroupMount(t *testing.T) {
t.Parallel()
micro := New(Config{
ErrorHandler: func(c Ctx, err error) error {
require.Equal(t, "0: GET error", err.Error())
return c.Status(500).SendString("1: custom error")
},
})
micro.Get("/doe", func(_ Ctx) error {
return errors.New("0: GET error")
})
app := New()
v1 := app.Group("/v1")
v1.Use("/john", micro)
resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody))
testErrorResponse(t, err, resp, "1: custom error")
}
func Test_App_ErrorHandler_GroupMountRootLevel(t *testing.T) {
t.Parallel()
micro := New(Config{
ErrorHandler: func(c Ctx, err error) error {
require.Equal(t, "0: GET error", err.Error())
return c.Status(500).SendString("1: custom error")
},
})
micro.Get("/john/doe", func(_ Ctx) error {
return errors.New("0: GET error")
})
app := New()
v1 := app.Group("/v1")
v1.Use("/", micro)
resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody))
testErrorResponse(t, err, resp, "1: custom error")
}
// go test -run Test_App_Group_Mount
// Test_App_Group_Mount checks that a sub-app mounted inside a group is reachable
// under the group prefix plus the mount prefix.
func Test_App_Group_Mount(t *testing.T) {
	t.Parallel()

	sub := New()
	sub.Get("/doe", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})

	app := New()
	group := app.Group("/v1")
	group.Use("/john", sub)

	req := httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
	require.Equal(t, uint32(2), app.handlersCount)
}
func Test_App_UseParentErrorHandler(t *testing.T) {
t.Parallel()
app := New(Config{
ErrorHandler: func(ctx Ctx, _ error) error {
return ctx.Status(500).SendString("hi, i'm a custom error")
},
})
fiber := New()
fiber.Get("/", func(_ Ctx) error {
return errors.New("something happened")
})
app.Use("/api", fiber)
resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", http.NoBody))
testErrorResponse(t, err, resp, "hi, i'm a custom error")
}
// Test_App_UseMountedErrorHandler checks that errors raised in a mounted sub-app
// are handled by the sub-app's own ErrorHandler.
func Test_App_UseMountedErrorHandler(t *testing.T) {
	t.Parallel()

	app := New()

	subErrorHandler := func(c Ctx, _ error) error {
		return c.Status(500).SendString("hi, i'm a custom error")
	}
	sub := New(Config{ErrorHandler: subErrorHandler})
	sub.Get("/", func(_ Ctx) error {
		return errors.New("something happened")
	})
	app.Use("/api", sub)

	req := httptest.NewRequest(MethodGet, "/api", http.NoBody)
	resp, err := app.Test(req)
	testErrorResponse(t, err, resp, "hi, i'm a custom error")
}
// Test_App_UseMountedErrorHandlerRootLevel checks that a sub-app mounted at "/"
// still handles its own errors with its ErrorHandler.
func Test_App_UseMountedErrorHandlerRootLevel(t *testing.T) {
	t.Parallel()

	app := New()

	subErrorHandler := func(c Ctx, _ error) error {
		return c.Status(500).SendString("hi, i'm a custom error")
	}
	sub := New(Config{ErrorHandler: subErrorHandler})
	sub.Get("/api", func(_ Ctx) error {
		return errors.New("something happened")
	})
	app.Use("/", sub)

	req := httptest.NewRequest(MethodGet, "/api", http.NoBody)
	resp, err := app.Test(req)
	testErrorResponse(t, err, resp, "hi, i'm a custom error")
}
// Test_App_UseMountedErrorHandlerForBestPrefixMatch mounts three apps with
// distinct ErrorHandlers ("/api" -> "/api/sub" -> "/api/sub/third") and checks
// that an error is handled by the ErrorHandler of the app with the longest
// matching mount prefix.
func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) {
	t.Parallel()
	app := New()
	tsf := func(c Ctx, _ error) error {
		return c.Status(200).SendString("hi, i'm a custom sub fiber error 2")
	}
	tripleSubFiber := New(Config{
		ErrorHandler: tsf,
	})
	tripleSubFiber.Get("/", func(_ Ctx) error {
		return errors.New("something happened")
	})
	sf := func(c Ctx, _ error) error {
		return c.Status(200).SendString("hi, i'm a custom sub fiber error")
	}
	subfiber := New(Config{
		ErrorHandler: sf,
	})
	subfiber.Get("/", func(_ Ctx) error {
		return errors.New("something happened")
	})
	subfiber.Use("/third", tripleSubFiber)
	f := func(c Ctx, _ error) error {
		return c.Status(200).SendString("hi, i'm a custom error")
	}
	fiber := New(Config{
		ErrorHandler: f,
	})
	fiber.Get("/", func(_ Ctx) error {
		return errors.New("something happened")
	})
	fiber.Use("/sub", subfiber)
	app.Use("/api", fiber)

	// "/api/sub" is served by subfiber, so its ErrorHandler must be used
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", http.NoBody))
	require.NoError(t, err, "/api/sub req")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	b, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "io.ReadAll()")
	require.Equal(t, "hi, i'm a custom sub fiber error", string(b), "Response body")

	// "/api/sub/third" is served by tripleSubFiber, the deepest mount
	resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", http.NoBody))
	require.NoError(t, err, "/api/sub/third req")
	require.Equal(t, 200, resp2.StatusCode, "Status code")
	b, err = io.ReadAll(resp2.Body)
	require.NoError(t, err, "io.ReadAll()")
	require.Equal(t, "hi, i'm a custom sub fiber error 2", string(b), "Third fiber Response body")
}
// go test -run Test_Mount_Route_Names
// Test_Mount_Route_Names checks that named routes of mounted sub-apps keep their
// names. Looked up on the sub-app they have the unprefixed path; looked up on the
// root app (after startupProcess) and from inside a handler they carry the mount prefix.
func Test_Mount_Route_Names(t *testing.T) {
	t.Parallel()
	// create sub-app with 2 handlers:
	subApp1 := New()
	subApp1.Get("/users", func(c Ctx) error {
		url, err := c.GetRouteURL("add-user", Map{})
		require.NoError(t, err)
		require.Equal(t, "/app1/users", url, "handler: app1.add-user") // the prefix is /app1 because of the mount
		// if subApp1 is not mounted, expected url just /users
		return nil
	}).Name("get-users")
	subApp1.Post("/users", func(c Ctx) error {
		route := c.App().GetRoute("get-users")
		require.Equal(t, MethodGet, route.Method, "handler: app1.get-users method")
		require.Equal(t, "/app1/users", route.Path, "handler: app1.get-users path")
		return nil
	}).Name("add-user")
	// create sub-app with 2 handlers inside a group:
	// the group name "users." is prepended to the route names ("users.get", "users.add")
	subApp2 := New()
	app2Grp := subApp2.Group("/users").Name("users.")
	app2Grp.Get("", emptyHandler).Name("get")
	app2Grp.Post("", emptyHandler).Name("add")
	// put both sub-apps into root app
	rootApp := New()
	_ = rootApp.Use("/app1", subApp1)
	_ = rootApp.Use("/app2", subApp2)
	// run startup explicitly so the mounted routes are available on rootApp before any request
	rootApp.startupProcess()
	// take route directly from sub-app
	route := subApp1.GetRoute("get-users")
	require.Equal(t, MethodGet, route.Method)
	require.Equal(t, "/users", route.Path)
	route = subApp1.GetRoute("add-user")
	require.Equal(t, MethodPost, route.Method)
	require.Equal(t, "/users", route.Path)
	// take route directly from sub-app with group
	route = subApp2.GetRoute("users.get")
	require.Equal(t, MethodGet, route.Method)
	require.Equal(t, "/users", route.Path)
	route = subApp2.GetRoute("users.add")
	require.Equal(t, MethodPost, route.Method)
	require.Equal(t, "/users", route.Path)
	// take route from root app (using names of sub-apps)
	route = rootApp.GetRoute("add-user")
	require.Equal(t, MethodPost, route.Method)
	require.Equal(t, "/app1/users", route.Path)
	route = rootApp.GetRoute("users.add")
	require.Equal(t, MethodPost, route.Method)
	require.Equal(t, "/app2/users", route.Path)
	// GetRouteURL inside handler
	req := httptest.NewRequest(MethodGet, "/app1/users", http.NoBody)
	resp, err := rootApp.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
	// ctx.App().GetRoute() inside handler
	req = httptest.NewRequest(MethodPost, "/app1/users", http.NoBody)
	resp, err = rootApp.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}
// go test -run Test_Ctx_Render_Mount
func Test_Ctx_Render_Mount(t *testing.T) {
t.Parallel()
engine := &testTemplateEngine{}
err := engine.Load()
require.NoError(t, err)
sub := New(Config{
Views: engine,
})
sub.Get("/:name", func(c Ctx) error {
return c.Render("hello_world.tmpl", Map{
"Name": c.Params("name"),
})
})
app := New()
app.Use("/hello", sub)
resp, err := app.Test(httptest.NewRequest(MethodGet, "/hello/a", http.NoBody))
require.Equal(t, StatusOK, resp.StatusCode, "Status code")
require.NoError(t, err, "app.Test(req)")
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "
", string(body))
}
================================================
FILE: path.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📄 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
// ⚠️ This path parser was inspired by https://github.com/ucarion/urlpath
// 💖 Maintained and modified for Fiber by @renewerner87
package fiber
import (
"bytes"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"time"
"unicode"
"github.com/gofiber/utils/v2"
utilsbytes "github.com/gofiber/utils/v2/bytes"
utilsstrings "github.com/gofiber/utils/v2/strings"
"github.com/google/uuid"
)
// routeParser holds the path segments and param names of a parsed route pattern.
type routeParser struct {
	segs          []*routeSegment // the parsed segments of the route, in pattern order
	params        []string        // the parameter names of the parsed route, in order of appearance
	wildCardCount int             // number of wildcard parameters, used internally to give the wildcard parameter its number
	plusCount     int             // number of plus parameters, used internally to give the plus parameter its number
}
// routerParserPool recycles *routeParser values so RoutePatternMatch can reuse
// a parser across calls instead of creating a new one each time.
var routerParserPool = &sync.Pool{
	New: func() any { return new(routeParser) },
}
// routeSegment holds the segment metadata of one piece of a parsed route:
// either a constant part (IsParam == false) or a parameter.
type routeSegment struct {
	// const information
	Const       string        // constant part of the route, escape characters already removed
	ParamName   string        // name of the parameter for access to it, for wildcards and plus parameters access iterators starting with 1 are added
	ComparePart string        // search part to find the end of the parameter (the constant that follows it)
	Constraints []*Constraint // constraints of the parameter; nil when the parameter has none
	PartCount   int           // how often is the search part contained in the non-param segments? -> necessary for greedy search
	Length      int           // length of the segment: len(Const) for constants; for parameters 1 when directly followed by another non-greedy parameter, otherwise 0 (undetermined)
	// future TODO: add support for optional groups "/abc(/def)?"
	// parameter information
	IsParam    bool // Truth value that indicates whether it is a parameter or a constant part
	IsGreedy   bool // indicates whether the parameter is greedy or not, is used with wildcard and plus
	IsOptional bool // indicates whether the parameter is optional or not
	// common information
	IsLast           bool // shows if the segment is the last one for the route
	HasOptionalSlash bool // segment has the possibility of an optional slash
}
// Special characters recognized by the route pattern parser.
const (
	wildcardParam                byte = '*'  // indicates an optional greedy parameter
	plusParam                    byte = '+'  // indicates a required greedy parameter
	optionalParam                byte = '?'  // concludes a parameter by name and makes it optional
	paramStarterChar             byte = ':'  // start character for a parameter with name
	slashDelimiter               byte = '/'  // separator for the route, unlike the other delimiters this character at the end can be optional
	escapeChar                   byte = '\\' // escape character
	paramConstraintStart         byte = '<'  // start of type constraint for a parameter
	paramConstraintEnd           byte = '>'  // end of type constraint for a parameter
	paramConstraintSeparator     byte = ';'  // separator of type constraints for a parameter
	paramConstraintDataStart     byte = '('  // start of data of type constraint for a parameter
	paramConstraintDataEnd       byte = ')'  // end of data of type constraint for a parameter
	paramConstraintDataSeparator byte = ','  // separator of data of type constraint for a parameter
)
// TypeConstraint identifies the kind of a built-in parameter constraint.
// Each value is a distinct bit, so several kinds can be combined into a mask.
type TypeConstraint uint16
// Constraint describes the validation rules that apply to a dynamic route
// segment when matching incoming requests.
type Constraint struct {
	RegexCompiler     *regexp.Regexp     // precompiled expression; only set when ID is regexConstraint
	Name              string             // constraint name as written in the route (the part before '(' when it has arguments)
	Data              []string           // constraint arguments; escape characters are removed except for regex data
	customConstraints []CustomConstraint // custom constraints passed to the parser when the route was parsed
	ID                TypeConstraint     // built-in constraint kind; noConstraint when Name is not a built-in constraint
}
// CustomConstraint is an interface for custom (user-defined) route parameter
// constraints.
type CustomConstraint interface {
	// Name returns the name of the constraint.
	// This name is used in the constraint matching.
	Name() string
	// Execute executes the constraint.
	// It returns true if the constraint is matched and right.
	// param is the parameter value to check.
	// args are the constraint arguments.
	Execute(param string, args ...string) bool
}
// Built-in constraint kinds. Each is a single bit (1 << iota); the order must
// not change because the values are derived from the position in this list.
const (
	noConstraint TypeConstraint = 1 << iota // no (or an unknown) built-in constraint
	intConstraint
	boolConstraint
	floatConstraint
	alphaConstraint
	datetimeConstraint
	guidConstraint
	minLenConstraint
	maxLenConstraint
	lenConstraint
	betweenLenConstraint
	minConstraint
	maxConstraint
	rangeConstraint
	regexConstraint
)
// Masks of the constraint kinds that take exactly one / exactly two data
// arguments. NOTE(review): presumably checked against len(Constraint.Data)
// when a constraint is evaluated — confirm in CheckConstraint.
const (
	needOneData = minLenConstraint | maxLenConstraint | lenConstraint | minConstraint | maxConstraint | datetimeConstraint | regexConstraint
	needTwoData = betweenLenConstraint | rangeConstraint
)
// list of possible parameter and segment delimiter
var (
	// slash has a special role, unlike the other delimiters it must not be interpreted as part of a parameter
	routeDelimiter = []byte{slashDelimiter, '-', '.'}
	// list of greedy parameters
	greedyParameters = []byte{wildcardParam, plusParam}
	// lookup table (indexed by byte value) of the chars that start a parameter
	parameterStartChars = [256]bool{
		wildcardParam:    true,
		plusParam:        true,
		paramStarterChar: true,
	}
	// list of chars of delimiters and the starting parameter name char;
	// a char after a parameter name that is NOT in this list (e.g. '?') is consumed as part of the parameter
	parameterDelimiterChars = append([]byte{paramStarterChar, escapeChar}, routeDelimiter...)
	// lookup table (indexed by byte value) of the chars that end a parameter name
	parameterEndChars = [256]bool{
		optionalParam:    true,
		paramStarterChar: true,
		escapeChar:       true,
		slashDelimiter:   true,
		'-':              true,
		'.':              true,
	}
)
// RoutePatternMatch reports whether path matches the provided Fiber route pattern.
//
// Patterns use the same syntax as routes registered on an App, including
// parameters (for example `:id`), wildcards (`*`, `+`), and optional segments.
// The optional Config argument can be used to control case sensitivity and
// strict routing behavior. This helper allows checking potential matches
// without registering a route.
//
// A pattern with more than maxParams parameters can never be registered as a
// route, so it never matches here either (instead of panicking).
func RoutePatternMatch(path, pattern string, cfg ...Config) bool {
	// See logic in (*Route).match and (*App).register
	var ctxParams [maxParams]string

	config := Config{}
	if len(cfg) > 0 {
		config = cfg[0]
	}

	if path == "" {
		path = "/"
	}

	// Cannot have an empty pattern
	if pattern == "" {
		pattern = "/"
	}
	// Pattern always starts with a '/'
	if pattern[0] != '/' {
		pattern = "/" + pattern
	}

	// Work on a private byte copy of pattern: the "Unsafe" lowercasing helper
	// below may modify its argument in place.
	patternPretty := []byte(pattern)

	// Case-insensitive routing: compare everything in lowercase
	if !config.CaseSensitive {
		patternPretty = utilsbytes.UnsafeToLower(patternPretty)
		path = utilsstrings.ToLower(path)
	}
	// Non-strict routing: remove trailing slashes from the pattern
	if !config.StrictRouting && len(patternPretty) > 1 {
		patternPretty = utils.TrimRight(patternPretty, '/')
	}

	parser, _ := routerParserPool.Get().(*routeParser) //nolint:errcheck // only contains routeParser
	defer routerParserPool.Put(parser)
	parser.reset()
	patternStr := string(patternPretty)
	parser.parseRoute(patternStr)

	// '*' wildcard matches any path
	if (patternStr == "/" && path == "/") || patternStr == "/*" {
		return true
	}

	// Does this route have parameters
	if paramCount := len(parser.params); paramCount > 0 {
		// getMatch writes every parameter into the fixed-size ctxParams array and
		// would panic with an index out of range for larger patterns; parseRoute
		// rejects such routes, so they cannot match.
		if paramCount > maxParams {
			return false
		}
		if parser.getMatch(path, path, &ctxParams, false) {
			return true
		}
	}

	// Check for a simple match
	patternPretty = RemoveEscapeCharBytes(patternPretty)
	return string(patternPretty) == path
}
// reset empties the parser so it can parse another pattern; the segment and
// parameter slices keep their capacity for reuse.
func (parser *routeParser) reset() {
	parser.segs, parser.params = parser.segs[:0], parser.params[:0]
	parser.wildCardCount, parser.plusCount = 0, 0
}
// parseRoute splits pattern into constant and parameter segments and appends
// them to the parser; parameter names are collected in parser.params. The last
// segment is flagged with IsLast and the segments are then enriched with the
// meta information used during matching (see addParameterMetaInfo).
func (parser *routeParser) parseRoute(pattern string, customConstraints ...CustomConstraint) {
	for rest := pattern; rest != ""; {
		var (
			consumed int
			segment  *routeSegment
		)
		if pos := findNextParamPosition(rest); pos == 0 {
			// the remaining pattern starts with a parameter
			consumed, segment = parser.analyseParameterPart(rest, customConstraints...)
			parser.params = append(parser.params, segment.ParamName)
		} else {
			consumed, segment = parser.analyseConstantPart(rest, pos)
		}
		parser.segs = append(parser.segs, segment)
		rest = rest[consumed:]
	}

	if count := len(parser.segs); count > 0 {
		parser.segs[count-1].IsLast = true
	}
	parser.segs = addParameterMetaInfo(parser.segs)
}
// parseRoute parses pattern into a new routeParser.
//
// It panics when the pattern declares more than maxParams parameters, because
// matched parameter values are stored in a fixed-size [maxParams]string array.
func parseRoute(pattern string, customConstraints ...CustomConstraint) routeParser {
	var parser routeParser
	parser.parseRoute(pattern, customConstraints...)

	if count := len(parser.params); count > maxParams {
		panic(fmt.Sprintf("Route '%s' has %d parameters, which exceeds the maximum of %d",
			pattern, count, maxParams))
	}
	return parser
}
// addParameterMetaInfo adds the meta information that simplifies finding the
// end of a parameter while matching:
//   - ComparePart: for every parameter, the constant part that follows it
//     (trailing slashes trimmed unless it is just "/")
//   - Length = 1 for a non-greedy parameter directly followed by another
//     non-greedy parameter
//   - PartCount: how often ComparePart occurs in the constant segments after
//     the parameter (needed for the right-to-left greedy search)
//   - HasOptionalSlash: for constant segments ending in '/' that are the last
//     segment or are followed by an optional parameter
//
// Constant segments may be empty (for example when they consisted only of
// escape characters, as in "/:param\"), so their last byte is only inspected
// when they have one.
func addParameterMetaInfo(segs []*routeSegment) []*routeSegment {
	var comparePart string
	segLen := len(segs)

	// loop from end to beginning, so every parameter sees the constant after it
	for i := segLen - 1; i >= 0; i-- {
		if segs[i].IsParam {
			// important for finding the end of the parameter
			segs[i].ComparePart = RemoveEscapeChar(comparePart)
			continue
		}
		comparePart = segs[i].Const
		if len(comparePart) > 1 {
			comparePart = utils.TrimRight(comparePart, slashDelimiter)
		}
	}

	// loop from beginning to end
	for i, seg := range segs {
		if !seg.IsParam {
			// the trailing slash of a constant is optional when nothing required follows it
			if n := len(seg.Const); n > 0 && seg.Const[n-1] == slashDelimiter &&
				(seg.IsLast || (i+1 < segLen && segs[i+1].IsOptional)) {
				seg.HasOptionalSlash = true
			}
			continue
		}
		// when neither this parameter nor the directly following one is greedy,
		// this parameter only takes one character
		if i+1 < segLen && !seg.IsGreedy && segs[i+1].IsParam && !segs[i+1].IsGreedy {
			seg.Length = 1
		}
		if seg.ComparePart == "" {
			continue
		}
		// count is important for the greedy match
		for _, next := range segs[i+1:] {
			if !next.IsParam {
				seg.PartCount += strings.Count(next.Const, seg.ComparePart)
			}
		}
	}
	return segs
}
// findNextParamPosition returns the index of the next unescaped parameter start
// char in pattern, or -1 when there is none. When the first such char is not at
// the beginning and is not a wildcard, a run of consecutive start chars is
// treated as a cluster and the index of its last char is returned.
func findNextParamPosition(pattern string) int {
	next := -1
	for i := 0; i < len(pattern); i++ {
		if !parameterStartChars[pattern[i]] {
			continue
		}
		if i > 0 && pattern[i-1] == escapeChar {
			continue
		}
		next = i
		break
	}

	if next <= 0 || pattern[next] == wildcardParam {
		return next
	}

	end := next + 1
	for end < len(pattern) && parameterStartChars[pattern[end]] {
		end++
	}
	return end - 1
}
// analyseConstantPart builds the segment for the constant part at the start of
// pattern. The part ends at nextParamPosition, or at the end of pattern when it
// is -1. It returns the number of bytes consumed (escape characters included)
// and the segment, whose Const has the escape characters removed.
func (*routeParser) analyseConstantPart(pattern string, nextParamPosition int) (int, *routeSegment) {
	end := len(pattern)
	if nextParamPosition != -1 {
		end = nextParamPosition
	}
	constant := RemoveEscapeChar(pattern[:end])
	segment := &routeSegment{Const: constant, Length: len(constant)}
	return end, segment
}
// analyseParameterPart find the parameter end and create the route segment
//
// pattern starts with a parameter start char (':', '*' or '+'); parseRoute only
// calls it when findNextParamPosition returned 0. The returned int is the number
// of bytes of pattern that belong to the parameter, including an optional
// "<...>" constraint block and a trailing '?'. Constraints are written as
// "<c1;c2(arg1,arg2)>". A regex constraint is compiled with regexp.MustCompile,
// which panics on an invalid expression.
func (parser *routeParser) analyseParameterPart(pattern string, customConstraints ...CustomConstraint) (int, *routeSegment) {
	isWildCard := pattern[0] == wildcardParam
	isPlusParam := pattern[0] == plusParam
	paramEndPosition := 0
	paramConstraintStartPosition := -1
	paramConstraintEndPosition := -1
	// handle wildcard end
	// ('*' and '+' are a single char, so paramEndPosition stays 0 for them)
	if !isWildCard && !isPlusParam {
		paramEndPosition = -1
		// search skips the leading ':'; index i in search is index i+1 in pattern
		search := pattern[1:]
		for i := range search {
			// first unescaped '<' opens the constraint block (stored as a pattern index)
			if paramConstraintStartPosition == -1 && search[i] == paramConstraintStart && (i == 0 || search[i-1] != escapeChar) {
				paramConstraintStartPosition = i + 1
				continue
			}
			// first unescaped '>' closes the constraint block (stored as a pattern index)
			if paramConstraintEndPosition == -1 && search[i] == paramConstraintEnd && (i == 0 || search[i-1] != escapeChar) {
				paramConstraintEndPosition = i + 1
				continue
			}
			// an end char only terminates the name outside of an open constraint block;
			// paramEndPosition = i is the pattern index of the char BEFORE the end char
			if parameterEndChars[search[i]] {
				if (paramConstraintStartPosition == -1 && paramConstraintEndPosition == -1) ||
					(paramConstraintStartPosition != -1 && paramConstraintEndPosition != -1) {
					paramEndPosition = i
					break
				}
			}
		}
		switch {
		case paramEndPosition == -1:
			// no end char found: the parameter runs to the end of the pattern
			paramEndPosition = len(pattern) - 1
		case bytes.IndexByte(parameterDelimiterChars, pattern[paramEndPosition+1]) == -1:
			// the end char is not a delimiter (e.g. '?'), so it belongs to the parameter
			paramEndPosition++
		default:
			// do nothing
		}
	}
	// cut params part
	processedPart := pattern[0 : paramEndPosition+1]
	n := paramEndPosition + 1
	paramName := RemoveEscapeChar(GetTrimmedParam(processedPart))
	// Check has constraint
	var constraints []*Constraint
	if hasConstraint := paramConstraintStartPosition != -1 && paramConstraintEndPosition != -1; hasConstraint {
		// the text between '<' and '>'
		constraintString := pattern[paramConstraintStartPosition+1 : paramConstraintEndPosition]
		userConstraints := splitNonEscaped(constraintString, paramConstraintSeparator)
		constraints = make([]*Constraint, 0, len(userConstraints))
		for _, c := range userConstraints {
			start := findNextNonEscapedCharPosition(c, paramConstraintDataStart)
			end := strings.LastIndexByte(c, paramConstraintDataEnd)
			// Assign constraint
			// constraint with arguments: "name(arg1,arg2)"
			if start != -1 && end != -1 {
				constraint := &Constraint{
					ID:                getParamConstraintType(c[:start]),
					Name:              c[:start],
					customConstraints: customConstraints,
				}
				// remove escapes from data
				if constraint.ID != regexConstraint {
					constraint.Data = splitNonEscaped(c[start+1:end], paramConstraintDataSeparator)
					if len(constraint.Data) == 1 {
						constraint.Data[0] = RemoveEscapeChar(constraint.Data[0])
					} else if len(constraint.Data) == 2 { // This is fine, we simply expect two parts
						constraint.Data[0] = RemoveEscapeChar(constraint.Data[0])
						constraint.Data[1] = RemoveEscapeChar(constraint.Data[1])
					}
				}
				// Precompile regex if has regex constraint
				// (the raw data is kept as-is: it is neither split nor unescaped)
				if constraint.ID == regexConstraint {
					constraint.Data = []string{c[start+1 : end]}
					constraint.RegexCompiler = regexp.MustCompile(constraint.Data[0])
				}
				constraints = append(constraints, constraint)
			} else {
				// constraint without arguments: "name"
				constraints = append(constraints, &Constraint{
					ID:                getParamConstraintType(c),
					Data:              []string{},
					Name:              c,
					customConstraints: customConstraints,
				})
			}
		}
		// the parameter name ends where the constraint block starts
		paramName = RemoveEscapeChar(GetTrimmedParam(pattern[0:paramConstraintStartPosition]))
	}
	// add access iterator to wildcard and plus
	// (the first '*' is named "*1", the second "*2", ...; likewise for '+')
	if isWildCard {
		parser.wildCardCount++
		paramName += strconv.Itoa(parser.wildCardCount)
	} else if isPlusParam {
		parser.plusCount++
		paramName += strconv.Itoa(parser.plusCount)
	}
	segment := &routeSegment{
		ParamName:  paramName,
		IsParam:    true,
		IsOptional: isWildCard || pattern[paramEndPosition] == optionalParam,
		IsGreedy:   isWildCard || isPlusParam,
	}
	if len(constraints) > 0 {
		segment.Constraints = constraints
	}
	return n, segment
}
// findNextNonEscapedCharPosition returns the index of the first occurrence of
// char in search that is not directly preceded by the escape character, or -1
// when there is none.
func findNextNonEscapedCharPosition(search string, char byte) int {
	for pos, current := range []byte(search) {
		if current != char {
			continue
		}
		if pos > 0 && search[pos-1] == escapeChar {
			continue
		}
		return pos
	}
	return -1
}
// splitNonEscaped splits s at every occurrence of sep that is not preceded by
// the escape character and returns the pieces between them. Escape characters
// are kept in the pieces. The result always has at least one element, so an
// empty s yields [""].
func splitNonEscaped(s string, sep byte) []string {
	var parts []string
	for {
		idx := findNextNonEscapedCharPosition(s, sep)
		if idx < 0 {
			return append(parts, s)
		}
		parts = append(parts, s[:idx])
		s = s[idx+1:]
	}
}
// hasPartialMatchBoundary reports whether a prefix match of matchedLength bytes
// ends on a segment boundary of path. That is the case when the whole path
// was consumed, or when a '/' sits directly before or directly after the cut.
// Out-of-range lengths and an empty match of a non-empty path are never boundaries.
func hasPartialMatchBoundary(path string, matchedLength int) bool {
	switch {
	case matchedLength < 0, matchedLength > len(path):
		return false
	case matchedLength == len(path):
		return true
	case matchedLength == 0:
		return false
	}
	// here 0 < matchedLength < len(path), so both neighbors of the cut exist
	return path[matchedLength-1] == slashDelimiter || path[matchedLength] == slashDelimiter
}
// getMatch parses the passed url and tries to match it against the route segments and determine the parameter positions
//
// detectionPath is the path used for matching; the parameter values are sliced
// from path at the same offsets and written to params in segment order.
// NOTE(review): this assumes path and detectionPath have the same byte length
// (e.g. detectionPath is a normalized copy of path) — confirm at the callers.
// params has room for maxParams values, so the route must not declare more
// parameters than that (parseRoute rejects such routes).
//
// Without partialCheck the whole detectionPath must be consumed. With
// partialCheck a prefix match is accepted as long as it ends on a segment
// boundary (see hasPartialMatchBoundary).
func (parser *routeParser) getMatch(detectionPath, path string, params *[maxParams]string, partialCheck bool) bool { //nolint:revive // Accepting a bool param is fine here
	originalDetectionPath := detectionPath
	var i, paramsIterator, partLen int
	for _, segment := range parser.segs {
		partLen = len(detectionPath)
		// check const segment
		if !segment.IsParam {
			i = segment.Length
			// is optional part or the const part must match with the given string
			// check if the end of the segment is an optional slash
			// (the remaining path may equal the constant without its trailing '/')
			if segment.HasOptionalSlash && partLen == i-1 && detectionPath == segment.Const[:i-1] {
				i--
			} else if i > partLen || detectionPath[:i] != segment.Const {
				return false
			}
		} else {
			// determine parameter length
			i = findParamLen(detectionPath, segment)
			// a required parameter must not be empty
			if !segment.IsOptional && i == 0 {
				return false
			}
			// take over the params positions
			params[paramsIterator] = path[:i]
			// constraints are skipped only for an optional parameter that matched nothing
			if !segment.IsOptional || i != 0 {
				// check constraint
				for _, c := range segment.Constraints {
					if matched := c.CheckConstraint(params[paramsIterator]); !matched {
						return false
					}
				}
			}
			paramsIterator++
		}
		// reduce founded part from the string
		if partLen > 0 {
			detectionPath, path = detectionPath[i:], path[i:]
		}
	}
	// unconsumed rest: only acceptable for a partial check ending on a boundary
	if detectionPath != "" {
		if !partialCheck {
			return false
		}
		consumedLength := len(originalDetectionPath) - len(detectionPath)
		if !hasPartialMatchBoundary(originalDetectionPath, consumedLength) {
			return false
		}
	}
	return true
}
// findParamLen for the expressjs wildcard behavior (right to left greedy)
// look at the other segments and take what is left for the wildcard from right to left
//
// It returns the number of bytes of s that belong to the parameter described
// by segment:
//   - last segment: see findParamLenForLastSegment
//   - fixed Length (adjacent non-greedy parameters): that length, if s is long enough
//   - greedy parameter whose ComparePart occurs more than once: see findGreedyParamLen
//   - otherwise: everything up to the first occurrence of ComparePart, or all
//     of s when ComparePart does not occur
//
// A non-greedy parameter never spans a '/': when the text before ComparePart
// contains one, 0 is returned (no match). This holds for single-byte compare
// parts (such as "-" or ".") as well as for longer ones.
func findParamLen(s string, segment *routeSegment) int {
	if segment.IsLast {
		return findParamLenForLastSegment(s, segment)
	}

	if segment.Length != 0 && len(s) >= segment.Length {
		return segment.Length
	} else if segment.IsGreedy {
		// Search the parameters until the next constant part
		// special logic for greedy params
		searchCount := strings.Count(s, segment.ComparePart)
		if searchCount > 1 {
			return findGreedyParamLen(s, searchCount, segment)
		}
	}

	if len(segment.ComparePart) == 1 {
		// fast path for single-byte compare parts such as "/", "-" or "."
		constPosition := strings.IndexByte(s, segment.ComparePart[0])
		if constPosition == -1 {
			return len(s)
		}
		// a "/" compare part is found at the first slash, so the prefix cannot
		// contain one; other single-byte parts need the same check as below
		if segment.ComparePart[0] != slashDelimiter && !segment.IsGreedy &&
			strings.IndexByte(s[:constPosition], slashDelimiter) != -1 {
			return 0
		}
		return constPosition
	}

	if constPosition := strings.Index(s, segment.ComparePart); constPosition != -1 {
		// if the compare part was found, but contains a slash although this part is not greedy, then it must not match
		// example: /api/:param/fixedEnd -> path: /api/123/456/fixedEnd = no match , /api/123/fixedEnd = match
		if !segment.IsGreedy && strings.IndexByte(s[:constPosition], slashDelimiter) != -1 {
			return 0
		}
		return constPosition
	}
	return len(s)
}
// findParamLenForLastSegment returns the length of the parameter when it is
// the last segment of the route. A greedy parameter takes all of s; any other
// parameter stops at the first '/'.
func findParamLenForLastSegment(s string, seg *routeSegment) int {
	if seg.IsGreedy {
		return len(s)
	}
	if slash := strings.IndexByte(s, slashDelimiter); slash >= 0 {
		return slash
	}
	return len(s)
}
// findGreedyParamLen returns the length of a greedy parameter. It cuts s at
// the last occurrence of segment.ComparePart, repeating this at most
// min(segment.PartCount, searchCount) times (or until the part is not found).
// That leaves room for the constant parts that still have to match to the
// right of the parameter.
func findGreedyParamLen(s string, searchCount int, segment *routeSegment) int {
	cuts := segment.PartCount
	if searchCount < cuts {
		cuts = searchCount
	}
	remaining := s
	for ; cuts > 0; cuts-- {
		cut := strings.LastIndex(remaining, segment.ComparePart)
		if cut < 0 {
			break
		}
		remaining = remaining[:cut]
	}
	return len(remaining)
}
// GetTrimmedParam trims the leading ':' and, for such parameters, a trailing
// '?' from param. Strings that do not start with ':' are returned unchanged.
func GetTrimmedParam(param string) string {
	if param == "" || param[0] != paramStarterChar {
		return param
	}
	name := param[1:]
	if last := len(name) - 1; last >= 0 && name[last] == optionalParam {
		name = name[:last]
	}
	return name
}
// RemoveEscapeChar removes escape characters
//
// Every backslash in word is dropped and all other bytes keep their order.
// strings.ReplaceAll returns its input unchanged when there is nothing to
// replace (no allocation in the current standard library). Otherwise it builds
// the result in a single exactly-sized allocation, instead of the []byte copy
// plus string conversion the hand-written loop needed.
func RemoveEscapeChar(word string) string {
	return strings.ReplaceAll(word, `\`, "")
}
// RemoveEscapeCharBytes removes escape characters
//
// The backslashes are filtered out in place: the returned slice shares word's
// backing array (the write position never overtakes the read position), so the
// caller's buffer is modified and no allocation happens.
func RemoveEscapeCharBytes(word []byte) []byte {
	kept := word[:0]
	for _, ch := range word {
		if ch == '\\' {
			continue
		}
		kept = append(kept, ch)
	}
	return kept
}
// getParamConstraintType translates a constraint name into its TypeConstraint
// identifier. The name is presumably the text parsed from a route parameter's
// constraint section; confirm against the caller.
// The min/max/between length constraints are matched under both their regular
// constant and their *Lower variant. Any unrecognized name yields noConstraint.
func getParamConstraintType(constraintPart string) (kind TypeConstraint) {
	switch constraintPart {
	case ConstraintInt:
		kind = intConstraint
	case ConstraintBool:
		kind = boolConstraint
	case ConstraintFloat:
		kind = floatConstraint
	case ConstraintAlpha:
		kind = alphaConstraint
	case ConstraintGUID:
		kind = guidConstraint
	case ConstraintMinLen, ConstraintMinLenLower:
		kind = minLenConstraint
	case ConstraintMaxLen, ConstraintMaxLenLower:
		kind = maxLenConstraint
	case ConstraintLen:
		kind = lenConstraint
	case ConstraintBetweenLen, ConstraintBetweenLenLower:
		kind = betweenLenConstraint
	case ConstraintMin:
		kind = minConstraint
	case ConstraintMax:
		kind = maxConstraint
	case ConstraintRange:
		kind = rangeConstraint
	case ConstraintDatetime:
		kind = datetimeConstraint
	case ConstraintRegex:
		kind = regexConstraint
	default:
		kind = noConstraint
	}
	return kind
}
// CheckConstraint validates if a param matches the given constraint
// Returns true if the param passes the constraint check, false otherwise
//
// A custom constraint whose Name() equals c.Name takes precedence over every
// built-in constraint and receives c.Data as its arguments. Built-in constraints
// fail closed: missing metadata (per the needOneData/needTwoData bits of c.ID),
// metadata that is not an integer, or an unknown c.ID all yield false.
// Length constraints compare len(param), i.e. the byte length.
func (c *Constraint) CheckConstraint(param string) bool {
	// Custom constraints override built-in ones with the same name.
	for _, custom := range c.customConstraints {
		if custom.Name() == c.Name {
			return custom.Execute(param, c.Data...)
		}
	}

	// Reject constraints that lack the metadata they require before indexing c.Data.
	switch {
	case c.ID&needOneData != 0 && len(c.Data) == 0:
		return false
	case c.ID&needTwoData != 0 && len(c.Data) < 2:
		return false
	}

	switch c.ID {
	case noConstraint:
		return true
	case intConstraint:
		_, convErr := strconv.Atoi(param)
		return convErr == nil
	case boolConstraint:
		_, convErr := strconv.ParseBool(param)
		return convErr == nil
	case floatConstraint:
		_, convErr := strconv.ParseFloat(param, 32)
		return convErr == nil
	case alphaConstraint:
		for _, r := range param {
			if !unicode.IsLetter(r) {
				return false
			}
		}
		return true
	case guidConstraint:
		_, convErr := uuid.Parse(param)
		return convErr == nil
	case minLenConstraint:
		limit, convErr := strconv.Atoi(c.Data[0])
		return convErr == nil && len(param) >= limit
	case maxLenConstraint:
		limit, convErr := strconv.Atoi(c.Data[0])
		return convErr == nil && len(param) <= limit
	case lenConstraint:
		limit, convErr := strconv.Atoi(c.Data[0])
		return convErr == nil && len(param) == limit
	case betweenLenConstraint:
		// Data[0] is validated before Data[1] is touched.
		lower, convErr := strconv.Atoi(c.Data[0])
		if convErr != nil {
			return false
		}
		upper, convErr := strconv.Atoi(c.Data[1])
		if convErr != nil {
			return false
		}
		size := len(param)
		return size >= lower && size <= upper
	case minConstraint:
		limit, convErr := strconv.Atoi(c.Data[0])
		if convErr != nil {
			return false
		}
		value, convErr := strconv.Atoi(param)
		return convErr == nil && value >= limit
	case maxConstraint:
		limit, convErr := strconv.Atoi(c.Data[0])
		if convErr != nil {
			return false
		}
		value, convErr := strconv.Atoi(param)
		return convErr == nil && value <= limit
	case rangeConstraint:
		// Data[0] is validated before Data[1] is touched.
		lower, convErr := strconv.Atoi(c.Data[0])
		if convErr != nil {
			return false
		}
		upper, convErr := strconv.Atoi(c.Data[1])
		if convErr != nil {
			return false
		}
		value, convErr := strconv.Atoi(param)
		return convErr == nil && value >= lower && value <= upper
	case datetimeConstraint:
		_, convErr := time.Parse(c.Data[0], param)
		return convErr == nil
	case regexConstraint:
		return c.RegexCompiler != nil && c.RegexCompiler.MatchString(param)
	default:
		return false
	}
}
================================================
FILE: path_test.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📝 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// go test -race -run Test_Path_parseRoute
//
// Test_Path_parseRoute checks the exact routeParser (segments, param names and
// wildcard/plus counters) that parseRoute builds for a set of route patterns.
func Test_Path_parseRoute(t *testing.T) {
	t.Parallel()
	expectations := []struct {
		pattern string
		want    routeParser
	}{
		{
			pattern: "/shop/product/::filter/color::color/size::size",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/shop/product/:", Length: 15},
					{IsParam: true, ParamName: "filter", ComparePart: "/color:", PartCount: 1},
					{Const: "/color:", Length: 7},
					{IsParam: true, ParamName: "color", ComparePart: "/size:", PartCount: 1},
					{Const: "/size:", Length: 6},
					{IsParam: true, ParamName: "size", IsLast: true},
				},
				params: []string{"filter", "color", "size"},
			},
		},
		{
			pattern: "/api/v1/:param/abc/*",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/api/v1/", Length: 8},
					{IsParam: true, ParamName: "param", ComparePart: "/abc", PartCount: 1},
					{Const: "/abc/", Length: 5, HasOptionalSlash: true},
					{IsParam: true, ParamName: "*1", IsGreedy: true, IsOptional: true, IsLast: true},
				},
				params:        []string{"param", "*1"},
				wildCardCount: 1,
			},
		},
		{
			pattern: "/v1/some/resource/name\\:customVerb",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/v1/some/resource/name:customVerb", Length: 33, IsLast: true},
				},
				params: nil,
			},
		},
		{
			pattern: "/v1/some/resource/:name\\:customVerb",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/v1/some/resource/", Length: 18},
					{IsParam: true, ParamName: "name", ComparePart: ":customVerb", PartCount: 1},
					{Const: ":customVerb", Length: 11, IsLast: true},
				},
				params: []string{"name"},
			},
		},
		// heavy test with escaped characters
		{
			pattern: "/v1/some/resource/name\\\\:customVerb?\\?/:param/*",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/v1/some/resource/name:customVerb??/", Length: 36},
					{IsParam: true, ParamName: "param", ComparePart: "/", PartCount: 1},
					{Const: "/", Length: 1, HasOptionalSlash: true},
					{IsParam: true, ParamName: "*1", IsGreedy: true, IsOptional: true, IsLast: true},
				},
				params:        []string{"param", "*1"},
				wildCardCount: 1,
			},
		},
		{
			pattern: "/api/*/:param/:param2",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/api/", Length: 5, HasOptionalSlash: true},
					{IsParam: true, ParamName: "*1", IsGreedy: true, IsOptional: true, ComparePart: "/", PartCount: 2},
					{Const: "/", Length: 1},
					{IsParam: true, ParamName: "param", ComparePart: "/", PartCount: 1},
					{Const: "/", Length: 1},
					{IsParam: true, ParamName: "param2", IsLast: true},
				},
				params:        []string{"*1", "param", "param2"},
				wildCardCount: 1,
			},
		},
		{
			pattern: "/test:optional?:optional2?",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/test", Length: 5},
					{IsParam: true, ParamName: "optional", IsOptional: true, Length: 1},
					{IsParam: true, ParamName: "optional2", IsOptional: true, IsLast: true},
				},
				params: []string{"optional", "optional2"},
			},
		},
		{
			pattern: "/config/+.json",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/config/", Length: 8},
					{IsParam: true, ParamName: "+1", IsGreedy: true, IsOptional: false, ComparePart: ".json", PartCount: 1},
					{Const: ".json", Length: 5, IsLast: true},
				},
				params:    []string{"+1"},
				plusCount: 1,
			},
		},
		{
			pattern: "/api/:day.:month?.:year?",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/api/", Length: 5},
					{IsParam: true, ParamName: "day", IsOptional: false, ComparePart: ".", PartCount: 2},
					{Const: ".", Length: 1},
					{IsParam: true, ParamName: "month", IsOptional: true, ComparePart: ".", PartCount: 1},
					{Const: ".", Length: 1},
					{IsParam: true, ParamName: "year", IsOptional: true, IsLast: true},
				},
				params: []string{"day", "month", "year"},
			},
		},
		{
			pattern: "/*v1*/proxy",
			want: routeParser{
				segs: []*routeSegment{
					{Const: "/", Length: 1, HasOptionalSlash: true},
					{IsParam: true, ParamName: "*1", IsGreedy: true, IsOptional: true, ComparePart: "v1", PartCount: 1},
					{Const: "v1", Length: 2},
					{IsParam: true, ParamName: "*2", IsGreedy: true, IsOptional: true, ComparePart: "/proxy", PartCount: 1},
					{Const: "/proxy", Length: 6, IsLast: true},
				},
				params:        []string{"*1", "*2"},
				wildCardCount: 2,
			},
		},
	}
	for _, exp := range expectations {
		require.Equal(t, exp.want, parseRoute(exp.pattern))
	}
}
// go test -race -run Test_Path_matchParams
//
// Test_Path_matchParams runs every routeTestCases entry through routeParser.getMatch
// and, for matching cases, compares the extracted params with the expected ones.
func Test_Path_matchParams(t *testing.T) {
	t.Parallel()
	var captured [maxParams]string
	for _, collection := range routeTestCases {
		parser := parseRoute(collection.pattern)
		for _, tc := range collection.testCases {
			matched := parser.getMatch(tc.url, tc.url, &captured, tc.partialCheck)
			require.Equal(t, tc.match, matched, "route: '%s', url: '%s'", collection.pattern, tc.url)
			if !matched || len(tc.params) == 0 {
				continue
			}
			n := len(tc.params)
			require.Equal(t, tc.params[:n], captured[:n], "route: '%s', url: '%s'", collection.pattern, tc.url)
		}
	}
}
// go test -race -run Test_RoutePatternMatch
func Test_RoutePatternMatch(t *testing.T) {
t.Parallel()
testCaseFn := func(pattern string, cases []routeTestCase) {
for _, c := range cases {
// skip all cases for partial checks
if c.partialCheck {
continue
}
match := RoutePatternMatch(c.url, pattern)
require.Equal(t, c.match, match, "route: '%s', url: '%s'", pattern, c.url)
}
}
for _, testCase := range routeTestCases {
testCaseFn(testCase.pattern, testCase.testCases)
}
}
func TestHasPartialMatchBoundary(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
path string
matchedLength int
expected bool
}{
{
name: "negative length",
path: "/demo",
matchedLength: -1,
expected: false,
},
{
name: "greater than length",
path: "/demo",
matchedLength: 6,
expected: false,
},
{
name: "exact match",
path: "/demo",
matchedLength: len("/demo"),
expected: true,
},
{
name: "zero length",
path: "/demo",
matchedLength: 0,
expected: false,
},
{
name: "previous rune slash",
path: "/demo/child",
matchedLength: len("/demo/"),
expected: true,
},
{
name: "next rune slash",
path: "/demo/child",
matchedLength: len("/demo"),
expected: true,
},
{
name: "no boundary",
path: "/demo/child",
matchedLength: len("/dem"),
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, testCase.expected, hasPartialMatchBoundary(testCase.path, testCase.matchedLength))
})
}
}
func Test_Utils_GetTrimmedParam(t *testing.T) {
t.Parallel()
res := GetTrimmedParam("")
require.Empty(t, res)
res = GetTrimmedParam("*")
require.Equal(t, "*", res)
res = GetTrimmedParam(":param")
require.Equal(t, "param", res)
res = GetTrimmedParam(":param1?")
require.Equal(t, "param1", res)
res = GetTrimmedParam("noParam")
require.Equal(t, "noParam", res)
}
func Test_Utils_RemoveEscapeChar(t *testing.T) {
t.Parallel()
res := RemoveEscapeChar(":test\\:bla")
require.Equal(t, ":test:bla", res)
res = RemoveEscapeChar("\\abc")
require.Equal(t, "abc", res)
res = RemoveEscapeChar("noEscapeChar")
require.Equal(t, "noEscapeChar", res)
}
func Test_ConstraintCheckConstraint_InvalidMetadata(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
param string
constraint Constraint
}{
{
name: "minLen invalid metadata",
constraint: Constraint{ID: minLenConstraint, Data: []string{"abc"}},
param: "abcd",
},
{
name: "maxLen invalid metadata",
constraint: Constraint{ID: maxLenConstraint, Data: []string{"abc"}},
param: "abcd",
},
{
name: "len invalid metadata",
constraint: Constraint{ID: lenConstraint, Data: []string{"abc"}},
param: "abcd",
},
{
name: "betweenLen invalid first metadata",
constraint: Constraint{ID: betweenLenConstraint, Data: []string{"abc", "5"}},
param: "abcd",
},
{
name: "betweenLen invalid second metadata",
constraint: Constraint{ID: betweenLenConstraint, Data: []string{"1", "abc"}},
param: "abcd",
},
{
name: "min invalid metadata",
constraint: Constraint{ID: minConstraint, Data: []string{"abc"}},
param: "10",
},
{
name: "max invalid metadata",
constraint: Constraint{ID: maxConstraint, Data: []string{"abc"}},
param: "10",
},
{
name: "range invalid first metadata",
constraint: Constraint{ID: rangeConstraint, Data: []string{"abc", "10"}},
param: "7",
},
{
name: "range invalid second metadata",
constraint: Constraint{ID: rangeConstraint, Data: []string{"1", "abc"}},
param: "7",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
require.False(t, testCase.constraint.CheckConstraint(testCase.param))
})
}
}
// Benchmark_Utils_RemoveEscapeChar measures RemoveEscapeChar on a short input with
// one escape sequence and verifies the final result.
func Benchmark_Utils_RemoveEscapeChar(b *testing.B) {
	const input = ":test\\:bla"
	b.ReportAllocs()
	var out string
	for b.Loop() {
		out = RemoveEscapeChar(input)
	}
	require.Equal(b, ":test:bla", out)
}
// Benchmark_Path_matchParams measures routeParser.getMatch for every benchmark case
// and afterwards verifies the match result and ALL expected params of the case.
//
// go test -run=^$ -bench=Benchmark_Path_matchParams -benchmem
func Benchmark_Path_matchParams(b *testing.B) {
	var ctxParams [maxParams]string
	for _, testCollection := range benchmarkCases {
		parser := parseRoute(testCollection.pattern)
		for _, c := range testCollection.testCases {
			state := "match"
			if !c.match {
				state = "not match"
			}
			b.Run(testCollection.pattern+" | "+state+" | "+c.url, func(sub *testing.B) {
				var matchRes bool
				for sub.Loop() {
					matchRes = parser.getMatch(c.url, c.url, &ctxParams, c.partialCheck)
				}
				// Assert on the sub-benchmark itself: FailNow must be called from the
				// goroutine running the benchmark that actually failed.
				require.Equal(sub, c.match, matchRes, "route: '%s', url: '%s'", testCollection.pattern, c.url)
				if matchRes && len(c.params) > 0 {
					// compare every expected param, including the last one
					require.Equal(sub, c.params, ctxParams[:len(c.params)], "route: '%s', url: '%s'", testCollection.pattern, c.url)
				}
			})
		}
	}
}
// go test -race -run Test_RoutePatternMatch
func Benchmark_RoutePatternMatch(t *testing.B) {
benchCaseFn := func(testCollection routeCaseCollection) {
for _, c := range testCollection.testCases {
// skip all cases for partial checks
if c.partialCheck {
continue
}
var matchRes bool
state := "match"
if !c.match {
state = "not match"
}
t.Run(testCollection.pattern+" | "+state+" | "+c.url, func(b *testing.B) {
for b.Loop() {
if match := RoutePatternMatch(c.url, testCollection.pattern); match {
// Get testCases from the original path
matchRes = true
}
}
require.Equal(t, c.match, matchRes, "route: '%s', url: '%s'", testCollection.pattern, c.url)
})
}
}
for _, testCollection := range benchmarkCases {
benchCaseFn(testCollection)
}
}
func Test_Route_TooManyParams_Panic(t *testing.T) {
t.Parallel()
// Test with exactly maxParams (30) - should work
t.Run("exactly_maxParams", func(t *testing.T) {
t.Parallel()
route := paramsRoute(t, maxParams)
require.NotPanics(t, func() {
parseRoute(route)
})
})
// Test with maxParams + 1 (31) - should panic
t.Run("maxParams_plus_one", func(t *testing.T) {
t.Parallel()
route := paramsRoute(t, maxParams+1)
require.PanicsWithValue(t, "Route '"+route+"' has 31 parameters, which exceeds the maximum of 30", func() {
parseRoute(route)
})
})
// Test with 35 params - should panic
t.Run("35_params", func(t *testing.T) {
t.Parallel()
route := paramsRoute(t, maxParams+5)
require.PanicsWithValue(t, "Route '"+route+"' has 35 parameters, which exceeds the maximum of 30", func() {
parseRoute(route)
})
})
}
// Test_App_Register_TooManyParams_Panic verifies that registering a route through the
// app panics when it exceeds maxParams parameters and works at exactly maxParams.
// The expected panic message is derived from maxParams instead of hardcoded numbers.
func Test_App_Register_TooManyParams_Panic(t *testing.T) {
	t.Parallel()
	// Test registering a route with too many params via app
	t.Run("register_via_Get", func(t *testing.T) {
		t.Parallel()
		app := New()
		route := paramsRoute(t, maxParams+1)
		want := fmt.Sprintf("Route '%s' has %d parameters, which exceeds the maximum of %d", route, maxParams+1, maxParams)
		require.PanicsWithValue(t, want, func() {
			app.Get(route, func(c Ctx) error {
				return c.SendString("test")
			})
		})
	})
	// Test registering a route with maxParams works
	t.Run("register_maxParams_works", func(t *testing.T) {
		t.Parallel()
		app := New()
		route := paramsRoute(t, maxParams)
		require.NotPanics(t, func() {
			app.Get(route, func(c Ctx) error {
				return c.SendString("test")
			})
		})
	})
}
// paramsRoute generates a route with n parameters for testing parseRoute maxParams condition.
// Returns a route in the format "/:p1/:p2/:p3/.../:pN"; for n == 0 the result is "/".
func paramsRoute(t *testing.T, n int) string {
	t.Helper()
	segments := make([]string, 0, n)
	for idx := 1; idx <= n; idx++ {
		segments = append(segments, fmt.Sprintf(":p%d", idx))
	}
	return "/" + strings.Join(segments, "/")
}
================================================
FILE: path_testcases_test.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📝 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"strings"
)
// routeTestCase is a single request path checked against a route pattern.
type routeTestCase struct {
	// url is the request path handed to getMatch / RoutePatternMatch.
	url string
	// params are the expected parameter values, compared against the leading
	// entries of the captured params when the case matches.
	params []string
	// match is the expected match result.
	match bool
	// partialCheck is passed as the partial-match flag to getMatch; such cases are
	// skipped by the RoutePatternMatch tests and benchmarks.
	partialCheck bool
}
// routeCaseCollection groups the test cases that share one route pattern.
type routeCaseCollection struct {
	// pattern is the route pattern passed to parseRoute / RoutePatternMatch.
	pattern string
	// testCases are the request paths checked against pattern.
	testCases []routeTestCase
}
// Shared route-matching fixtures, populated by init.
var (
	// benchmarkCases is the smaller set of collections used by the benchmarks.
	benchmarkCases []routeCaseCollection
	// routeTestCases starts with benchmarkCases and appends further collections;
	// it is used by the correctness tests.
	routeTestCases []routeCaseCollection
)
func init() {
// smaller list for benchmark cases
benchmarkCases = []routeCaseCollection{
{
pattern: "/api/v1/const",
testCases: []routeTestCase{
{url: "/api/v1/const", params: []string{}, match: true},
{url: "/api/v1", params: nil, match: false},
{url: "/api/v1/", params: nil, match: false},
{url: "/api/v1/something", params: nil, match: false},
},
},
{
pattern: "/api/:param/fixedEnd",
testCases: []routeTestCase{
{url: "/api/abc/fixedEnd", params: []string{"abc"}, match: true},
{url: "/api/abc/def/fixedEnd", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param/*",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: []string{"entity", ""}, match: true},
{url: "/api/v1/entity/", params: []string{"entity", ""}, match: true},
{url: "/api/v1/entity/1", params: []string{"entity", "1"}, match: true},
{url: "/api/v", params: nil, match: false},
{url: "/api/v2", params: nil, match: false},
{url: "/api/v1/", params: nil, match: false},
},
},
}
// combine benchmark cases and other cases
routeTestCases = benchmarkCases
routeTestCases = append(
routeTestCases,
[]routeCaseCollection{
{
pattern: "/api/v1/:param/+",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/entity/", params: nil, match: false},
{url: "/api/v1/entity/1", params: []string{"entity", "1"}, match: true},
{url: "/api/v", params: nil, match: false},
{url: "/api/v2", params: nil, match: false},
{url: "/api/v1/", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param?",
testCases: []routeTestCase{
{url: "/api/v1", params: []string{""}, match: true},
{url: "/api/v1/", params: []string{""}, match: true},
{url: "/api/v1/optional", params: []string{"optional"}, match: true},
{url: "/api/v", params: nil, match: false},
{url: "/api/v2", params: nil, match: false},
{url: "/api/xyz", params: nil, match: false},
},
},
{
pattern: `/v1/some/resource/name\:customVerb`,
testCases: []routeTestCase{
{url: "/v1/some/resource/name:customVerb", params: nil, match: true},
{url: "/v1/some/resource/name:test", params: nil, match: false},
},
},
{
pattern: `/v1/some/resource/:name\:customVerb`,
testCases: []routeTestCase{
{url: "/v1/some/resource/test:customVerb", params: []string{"test"}, match: true},
{url: "/v1/some/resource/test:test", params: nil, match: false},
},
},
{
pattern: `/v1/some/resource/name\\:customVerb?\?/:param/*`,
testCases: []routeTestCase{
{url: "/v1/some/resource/name:customVerb??/test/optionalWildCard/character", params: []string{"test", "optionalWildCard/character"}, match: true},
{url: "/v1/some/resource/name:customVerb??/test", params: []string{"test", ""}, match: true},
},
},
{
pattern: "/api/v1/*",
testCases: []routeTestCase{
{url: "/api/v1", params: []string{""}, match: true},
{url: "/api/v1/", params: []string{""}, match: true},
{url: "/api/v1/entity", params: []string{"entity"}, match: true},
{url: "/api/v1/entity/1/2", params: []string{"entity/1/2"}, match: true},
{url: "/api/v1/Entity/1/2", params: []string{"Entity/1/2"}, match: true},
{url: "/api/v", params: nil, match: false},
{url: "/api/v2", params: nil, match: false},
{url: "/api/abc", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: []string{"entity"}, match: true},
{url: "/api/v1/entity/8728382", params: nil, match: false},
{url: "/api/v1", params: nil, match: false},
{url: "/api/v1/", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param-:param2",
testCases: []routeTestCase{
{url: "/api/v1/entity-entity2", params: []string{"entity", "entity2"}, match: true},
{url: "/api/v1/entity/8728382", params: nil, match: false},
{url: "/api/v1/entity-8728382", params: []string{"entity", "8728382"}, match: true},
{url: "/api/v1", params: nil, match: false},
{url: "/api/v1/", params: nil, match: false},
},
},
{
pattern: "/api/v1/:filename.:extension",
testCases: []routeTestCase{
{url: "/api/v1/test.pdf", params: []string{"test", "pdf"}, match: true},
{url: "/api/v1/test/pdf", params: nil, match: false},
{url: "/api/v1/test-pdf", params: nil, match: false},
{url: "/api/v1/test_pdf", params: nil, match: false},
{url: "/api/v1", params: nil, match: false},
{url: "/api/v1/", params: nil, match: false},
},
},
{
pattern: "/shop/product/::filter/color::color/size::size",
testCases: []routeTestCase{
{url: "/shop/product/:test/color:blue/size:xs", params: []string{"test", "blue", "xs"}, match: true},
{url: "/shop/product/test/color:blue/size:xs", params: nil, match: false},
},
},
{
pattern: "/::param?",
testCases: []routeTestCase{
{url: "/:hello", params: []string{"hello"}, match: true},
{url: "/:", params: []string{""}, match: true},
{url: "/", params: nil, match: false},
},
},
// successive parameters, each take one character and the last parameter gets everything
{
pattern: "/test:sign:param",
testCases: []routeTestCase{
{url: "/test-abc", params: []string{"-", "abc"}, match: true},
{url: "/test", params: nil, match: false},
},
},
// optional parameters are not greedy
{
pattern: "/:param1:param2?:param3",
testCases: []routeTestCase{
{url: "/abbbc", params: []string{"a", "b", "bbc"}, match: true},
// {url: "/ac", testCases: []string{"a", "", "c"}, match: true}, // TODO: fix it
{url: "/test", params: []string{"t", "e", "st"}, match: true},
},
},
{
pattern: "/test:optional?:mandatory",
testCases: []routeTestCase{
// {url: "/testo", testCases: []string{"", "o"}, match: true}, // TODO: fix it
{url: "/testoaaa", params: []string{"o", "aaa"}, match: true},
{url: "/test", params: nil, match: false},
},
},
{
pattern: "/test:optional?:optional2?",
testCases: []routeTestCase{
{url: "/testo", params: []string{"o", ""}, match: true},
{url: "/testoaaa", params: []string{"o", "aaa"}, match: true},
{url: "/test", params: []string{"", ""}, match: true},
{url: "/tes", params: nil, match: false},
},
},
{
pattern: "/foo:param?bar",
testCases: []routeTestCase{
{url: "/foofaselbar", params: []string{"fasel"}, match: true},
{url: "/foobar", params: []string{""}, match: true},
{url: "/fooba", params: nil, match: false},
{url: "/fobar", params: nil, match: false},
},
},
{
pattern: "/foo*bar",
testCases: []routeTestCase{
{url: "/foofaselbar", params: []string{"fasel"}, match: true},
{url: "/foobar", params: []string{""}, match: true},
{url: "/", params: nil, match: false},
},
},
{
pattern: "/foo+bar",
testCases: []routeTestCase{
{url: "/foofaselbar", params: []string{"fasel"}, match: true},
{url: "/foobar", params: nil, match: false},
{url: "/", params: nil, match: false},
},
},
{
pattern: "/a*cde*g/",
testCases: []routeTestCase{
{url: "/abbbcdefffg", params: []string{"bbb", "fff"}, match: true},
{url: "/acdeg", params: []string{"", ""}, match: true},
{url: "/", params: nil, match: false},
},
},
{
pattern: "/*v1*/proxy",
testCases: []routeTestCase{
{url: "/customer/v1/cart/proxy", params: []string{"customer/", "/cart"}, match: true},
{url: "/v1/proxy", params: []string{"", ""}, match: true},
{url: "/v1/", params: nil, match: false},
},
},
// successive wildcard -> first wildcard is greedy
{
pattern: "/foo***bar",
testCases: []routeTestCase{
{url: "/foo*abar", params: []string{"*a", "", ""}, match: true},
{url: "/foo*bar", params: []string{"*", "", ""}, match: true},
{url: "/foobar", params: []string{"", "", ""}, match: true},
{url: "/fooba", params: nil, match: false},
},
},
// chars in front of a parameter
{
pattern: "/name::name",
testCases: []routeTestCase{
{url: "/name:john", params: []string{"john"}, match: true},
},
},
{
pattern: "/@:name",
testCases: []routeTestCase{
{url: "/@john", params: []string{"john"}, match: true},
},
},
{
pattern: "/-:name",
testCases: []routeTestCase{
{url: "/-john", params: []string{"john"}, match: true},
},
},
{
pattern: "/.:name",
testCases: []routeTestCase{
{url: "/.john", params: []string{"john"}, match: true},
},
},
{
pattern: "/api/v1/:param/abc/*",
testCases: []routeTestCase{
{url: "/api/v1/well/abc/wildcard", params: []string{"well", "wildcard"}, match: true},
{url: "/api/v1/well/abc/", params: []string{"well", ""}, match: true},
{url: "/api/v1/well/abc", params: []string{"well", ""}, match: true},
{url: "/api/v1/well/ttt", params: nil, match: false},
},
},
{
pattern: "/api/:day/:month?/:year?",
testCases: []routeTestCase{
{url: "/api/1", params: []string{"1", "", ""}, match: true},
{url: "/api/1/", params: []string{"1", "", ""}, match: true},
{url: "/api/1//", params: []string{"1", "", ""}, match: true},
{url: "/api/1/-/", params: []string{"1", "-", ""}, match: true},
{url: "/api/1-", params: []string{"1-", "", ""}, match: true},
{url: "/api/1.", params: []string{"1.", "", ""}, match: true},
{url: "/api/1/2", params: []string{"1", "2", ""}, match: true},
{url: "/api/1/2/3", params: []string{"1", "2", "3"}, match: true},
{url: "/api/", params: nil, match: false},
},
},
{
pattern: "/api/:day.:month?.:year?",
testCases: []routeTestCase{
{url: "/api/1", params: nil, match: false},
{url: "/api/1/", params: nil, match: false},
{url: "/api/1.", params: nil, match: false},
{url: "/api/1..", params: []string{"1", "", ""}, match: true},
{url: "/api/1.2", params: nil, match: false},
{url: "/api/1.2.", params: []string{"1", "2", ""}, match: true},
{url: "/api/1.2.3", params: []string{"1", "2", "3"}, match: true},
{url: "/api/", params: nil, match: false},
},
},
{
pattern: "/api/:day-:month?-:year?",
testCases: []routeTestCase{
{url: "/api/1", params: nil, match: false},
{url: "/api/1/", params: nil, match: false},
{url: "/api/1-", params: nil, match: false},
{url: "/api/1--", params: []string{"1", "", ""}, match: true},
{url: "/api/1-/", params: nil, match: false},
// {url: "/api/1-/-", testCases: nil, match: false}, // TODO: fix this part
{url: "/api/1-2", params: nil, match: false},
{url: "/api/1-2-", params: []string{"1", "2", ""}, match: true},
{url: "/api/1-2-3", params: []string{"1", "2", "3"}, match: true},
{url: "/api/", params: nil, match: false},
},
},
{
pattern: "/api/*",
testCases: []routeTestCase{
{url: "/api/", params: []string{""}, match: true},
{url: "/api/joker", params: []string{"joker"}, match: true},
{url: "/api", params: []string{""}, match: true},
{url: "/api/v1/entity", params: []string{"v1/entity"}, match: true},
{url: "/api2/v1/entity", params: nil, match: false},
{url: "/api_ignore/v1/entity", params: nil, match: false},
},
},
{
pattern: "/partialCheck/foo/",
testCases: []routeTestCase{
{url: "/partialCheck/foo/", params: nil, match: true, partialCheck: true},
{url: "/partialCheck/foo/bar", params: nil, match: true, partialCheck: true},
{url: "/partialCheck/foo/bar/baz", params: nil, match: true, partialCheck: true},
{url: "/partialCheck/foobar", params: nil, match: false, partialCheck: true},
},
},
{
pattern: "/partialCheck/foo/bar/:param",
testCases: []routeTestCase{
{url: "/partialCheck/foo/bar/test", params: []string{"test"}, match: true, partialCheck: true},
{url: "/partialCheck/foo/bar/test/test2", params: []string{"test"}, match: true, partialCheck: true},
{url: "/partialCheck/foo/bar", params: nil, match: false, partialCheck: true},
{url: "/partialFoo", params: nil, match: false, partialCheck: true},
},
},
{
pattern: "/partialCheck/foo",
testCases: []routeTestCase{
{url: "/partialCheck/foo", params: nil, match: true, partialCheck: true},
{url: "/partialCheck/foo/", params: nil, match: true, partialCheck: true},
{url: "/partialCheck/foo/bar", params: nil, match: true, partialCheck: true},
{url: "/partialCheck/foobar", params: nil, match: false, partialCheck: true},
},
},
{
pattern: "/partialCheck/:param",
testCases: []routeTestCase{
{url: "/partialCheck/value", params: []string{"value"}, match: true, partialCheck: true},
{url: "/partialCheck/value/", params: []string{"value"}, match: true, partialCheck: true},
{url: "/partialCheck/value/next", params: []string{"value"}, match: true, partialCheck: true},
},
},
{
pattern: "/",
testCases: []routeTestCase{
{url: "/api", params: nil, match: false},
{url: "", params: []string{}, match: true},
{url: "/", params: []string{}, match: true},
},
},
{
pattern: "/config/abc.json",
testCases: []routeTestCase{
{url: "/config/abc.json", params: []string{}, match: true},
{url: "config/abc.json", params: nil, match: false},
{url: "/config/efg.json", params: nil, match: false},
{url: "/config", params: nil, match: false},
},
},
{
pattern: "/config/*.json",
testCases: []routeTestCase{
{url: "/config/abc.json", params: []string{"abc"}, match: true},
{url: "/config/efg.json", params: []string{"efg"}, match: true},
{url: "/config/.json", params: []string{""}, match: true},
{url: "/config/efg.csv", params: nil, match: false},
{url: "config/abc.json", params: nil, match: false},
{url: "/config", params: nil, match: false},
},
},
{
pattern: "/config/+.json",
testCases: []routeTestCase{
{url: "/config/abc.json", params: []string{"abc"}, match: true},
{url: "/config/.json", params: nil, match: false},
{url: "/config/efg.json", params: []string{"efg"}, match: true},
{url: "/config/efg.csv", params: nil, match: false},
{url: "config/abc.json", params: nil, match: false},
{url: "/config", params: nil, match: false},
},
},
{
pattern: "/xyz",
testCases: []routeTestCase{
{url: "xyz", params: nil, match: false},
{url: "xyz/", params: nil, match: false},
},
},
{
pattern: "/api/*/:param?",
testCases: []routeTestCase{
{url: "/api/", params: []string{"", ""}, match: true},
{url: "/api/joker", params: []string{"joker", ""}, match: true},
{url: "/api/joker/batman", params: []string{"joker", "batman"}, match: true},
{url: "/api/joker//batman", params: []string{"joker/", "batman"}, match: true},
{url: "/api/joker/batman/robin", params: []string{"joker/batman", "robin"}, match: true},
{url: "/api/joker/batman/robin/1", params: []string{"joker/batman/robin", "1"}, match: true},
{url: "/api/joker/batman/robin/1/", params: []string{"joker/batman/robin/1", ""}, match: true},
{url: "/api/joker-batman/robin/1", params: []string{"joker-batman/robin", "1"}, match: true},
{url: "/api/joker-batman-robin/1", params: []string{"joker-batman-robin", "1"}, match: true},
{url: "/api/joker-batman-robin-1", params: []string{"joker-batman-robin-1", ""}, match: true},
{url: "/api", params: []string{"", ""}, match: true},
},
},
{
pattern: "/api/*/:param",
testCases: []routeTestCase{
{url: "/api/test/abc", params: []string{"test", "abc"}, match: true},
{url: "/api/joker/batman", params: []string{"joker", "batman"}, match: true},
{url: "/api/joker/batman/robin", params: []string{"joker/batman", "robin"}, match: true},
{url: "/api/joker/batman/robin/1", params: []string{"joker/batman/robin", "1"}, match: true},
{url: "/api/joker/batman-robin/1", params: []string{"joker/batman-robin", "1"}, match: true},
{url: "/api/joker-batman-robin-1", params: nil, match: false},
{url: "/api", params: nil, match: false},
},
},
{
pattern: "/api/+/:param",
testCases: []routeTestCase{
{url: "/api/test/abc", params: []string{"test", "abc"}, match: true},
{url: "/api/joker/batman/robin/1", params: []string{"joker/batman/robin", "1"}, match: true},
{url: "/api/joker", params: nil, match: false},
{url: "/api", params: nil, match: false},
},
},
{
pattern: "/api/*/:param/:param2",
testCases: []routeTestCase{
{url: "/api/test/abc/1", params: []string{"test", "abc", "1"}, match: true},
{url: "/api/joker/batman", params: nil, match: false},
{url: "/api/joker/batman-robin/1", params: []string{"joker", "batman-robin", "1"}, match: true},
{url: "/api/joker-batman-robin-1", params: nil, match: false},
{url: "/api/test/abc", params: nil, match: false},
{url: "/api/joker/batman/robin", params: []string{"joker", "batman", "robin"}, match: true},
{url: "/api/joker/batman/robin/1", params: []string{"joker/batman", "robin", "1"}, match: true},
{url: "/api/joker/batman/robin/1/2", params: []string{"joker/batman/robin", "1", "2"}, match: true},
{url: "/api", params: nil, match: false},
{url: "/api/:test", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/8728382", params: []string{"8728382"}, match: true},
{url: "/api/v1/true", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/8728382", params: nil, match: false},
{url: "/api/v1/true", params: []string{"true"}, match: true},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/8728382", params: []string{"8728382"}, match: true},
{url: "/api/v1/8728382.5", params: []string{"8728382.5"}, match: true},
{url: "/api/v1/true", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: []string{"entity"}, match: true},
{url: "/api/v1/#!?", params: nil, match: false},
{url: "/api/v1/8728382", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/8728382", params: nil, match: false},
{url: "/api/v1/f0fa66cc-d22e-445b-866d-1d76e776371d", params: []string{"f0fa66cc-d22e-445b-866d-1d76e776371d"}, match: true},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/8728382", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: []string{"entity"}, match: true},
{url: "/api/v1/ent", params: nil, match: false},
{url: "/api/v1/8728382", params: []string{"8728382"}, match: true},
{url: "/api/v1/123", params: nil, match: false},
{url: "/api/v1/12345", params: []string{"12345"}, match: true},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/ent", params: []string{"ent"}, match: true},
{url: "/api/v1/8728382", params: nil, match: false},
{url: "/api/v1/123", params: []string{"123"}, match: true},
{url: "/api/v1/12345", params: []string{"12345"}, match: true},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/ent", params: nil, match: false},
{url: "/api/v1/123", params: nil, match: false},
{url: "/api/v1/12345", params: []string{"12345"}, match: true},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/ent", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/e", params: nil, match: false},
{url: "/api/v1/en", params: []string{"en"}, match: true},
{url: "/api/v1/8728382", params: nil, match: false},
{url: "/api/v1/123", params: []string{"123"}, match: true},
{url: "/api/v1/12345", params: []string{"12345"}, match: true},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/e", params: nil, match: false},
{url: "/api/v1/en", params: []string{"en"}, match: true},
{url: "/api/v1/8728382", params: nil, match: false},
{url: "/api/v1/123", params: []string{"123"}, match: true},
{url: "/api/v1/12345", params: []string{"12345"}, match: true},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/ent", params: nil, match: false},
{url: "/api/v1/1", params: nil, match: false},
{url: "/api/v1/5", params: []string{"5"}, match: true},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/ent", params: nil, match: false},
{url: "/api/v1/1", params: []string{"1"}, match: true},
{url: "/api/v1/5", params: []string{"5"}, match: true},
{url: "/api/v1/15", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/ent", params: nil, match: false},
{url: "/api/v1/9", params: []string{"9"}, match: true},
{url: "/api/v1/5", params: []string{"5"}, match: true},
{url: "/api/v1/15", params: nil, match: false},
},
},
{
pattern: `/api/v1/:param`,
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/8728382", params: nil, match: false},
{url: "/api/v1/2005-11-01", params: []string{"2005-11-01"}, match: true},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/ent", params: nil, match: false},
{url: "/api/v1/15", params: nil, match: false},
{url: "/api/v1/peach", params: []string{"peach"}, match: true},
{url: "/api/v1/p34ch", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/12", params: nil, match: false},
{url: "/api/v1/xy", params: nil, match: false},
{url: "/api/v1/test", params: []string{"test"}, match: true},
{url: "/api/v1/" + strings.Repeat("a", 64), params: nil, match: false},
},
},
{
pattern: `/api/v1/:param`,
testCases: []routeTestCase{
{url: "/api/v1/ent", params: nil, match: false},
{url: "/api/v1/15", params: nil, match: false},
{url: "/api/v1/2022-08-27", params: []string{"2022-08-27"}, match: true},
{url: "/api/v1/2022/08-27", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/8728382", params: []string{"8728382"}, match: true},
{url: "/api/v1/true", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/8728382", params: nil, match: false},
{url: "/api/v1/123", params: []string{"123"}, match: true},
{url: "/api/v1/true", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/87283827683", params: nil, match: false},
{url: "/api/v1/123", params: []string{"123"}, match: true},
{url: "/api/v1/true", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/87283827683", params: nil, match: false},
{url: "/api/v1/25", params: []string{"25"}, match: true},
{url: "/api/v1/true", params: nil, match: false},
},
},
{
pattern: `/api/v1/:param`,
testCases: []routeTestCase{
{url: "/api/v1/entity", params: []string{"entity"}, match: true},
{url: "/api/v1/87283827683", params: []string{"87283827683"}, match: true},
{url: "/api/v1/25", params: []string{"25"}, match: true},
{url: "/api/v1/true", params: []string{"true"}, match: true},
},
},
{
pattern: `/api/v1/:param`,
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/87283827683", params: nil, match: false},
{url: "/api/v1/25", params: nil, match: false},
{url: "/api/v1/1200", params: nil, match: false},
{url: "/api/v1/true", params: nil, match: false},
},
},
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/87283827683", params: nil, match: false},
{url: "/api/v1/25", params: []string{"25"}, match: true},
{url: "/api/v1/1200", params: []string{"1200"}, match: true},
{url: "/api/v1/true", params: nil, match: false},
},
},
{
pattern: "/api/v1/:lang/videos/:page",
testCases: []routeTestCase{
{url: "/api/v1/try/videos/200", params: nil, match: false},
{url: "/api/v1/tr/videos/1800", params: nil, match: false},
{url: "/api/v1/tr/videos/100", params: []string{"tr", "100"}, match: true},
{url: "/api/v1/e/videos/10", params: nil, match: false},
},
},
{
pattern: "/api/v1/:lang/:page",
testCases: []routeTestCase{
{url: "/api/v1/try/200", params: nil, match: false},
{url: "/api/v1/tr/1800", params: nil, match: false},
{url: "/api/v1/tr/100", params: []string{"tr", "100"}, match: true},
{url: "/api/v1/e/10", params: nil, match: false},
},
},
{
pattern: "/api/v1/:lang/:page",
testCases: []routeTestCase{
{url: "/api/v1/try/200", params: []string{"try", "200"}, match: true},
{url: "/api/v1/tr/1800", params: nil, match: false},
{url: "/api/v1/tr/100", params: []string{"tr", "100"}, match: true},
{url: "/api/v1/e/10", params: nil, match: false},
},
},
{
pattern: "/api/v1/:lang/:page",
testCases: []routeTestCase{
{url: "/api/v1/try/200", params: nil, match: false},
{url: "/api/v1/tr/1800", params: []string{"tr", "1800"}, match: true},
{url: "/api/v1/tr/100", params: []string{"tr", "100"}, match: true},
{url: "/api/v1/e/10", params: nil, match: false},
},
},
{
pattern: `/api/v1/:date/:regex`,
testCases: []routeTestCase{
{url: "/api/v1/2005-11-01/a", params: nil, match: false},
{url: "/api/v1/2005-1101/paach", params: nil, match: false},
{url: "/api/v1/2005-11-01/peach", params: []string{"2005-11-01", "peach"}, match: true},
},
},
{
pattern: "/api/v1/:param?",
testCases: []routeTestCase{
{url: "/api/v1/entity", params: nil, match: false},
{url: "/api/v1/8728382", params: []string{"8728382"}, match: true},
{url: "/api/v1/true", params: nil, match: false},
{url: "/api/v1/", params: []string{""}, match: true},
},
},
// Add test case for RegexCompiler == nil
{
pattern: "/api/v1/:param",
testCases: []routeTestCase{
{url: "/api/v1/123", params: []string{"123"}, match: true},
{url: "/api/v1/abc", params: nil, match: false},
},
},
}...,
)
}
================================================
FILE: prefork.go
================================================
package fiber
import (
"crypto/tls"
"errors"
"fmt"
"net"
"os"
"os/exec"
"runtime"
"sync/atomic"
"time"
"github.com/valyala/fasthttp/reuseport"
"github.com/gofiber/fiber/v3/log"
)
const (
envPreforkChildKey = "FIBER_PREFORK_CHILD"
envPreforkChildVal = "1"
sleepDuration = 100 * time.Millisecond
windowsOS = "windows"
)
var (
testPreforkMaster = false
testOnPrefork = false
)
// IsChild determines if the current process is a child of Prefork
func IsChild() bool {
return os.Getenv(envPreforkChildKey) == envPreforkChildVal
}
// prefork manages child processes to make use of the OS REUSEPORT or REUSEADDR feature
func (app *App) prefork(addr string, tlsConfig *tls.Config, cfg *ListenConfig) error {
if cfg == nil {
cfg = &ListenConfig{}
}
var ln net.Listener
var err error
// 👶 child process 👶
if IsChild() {
// use 1 cpu core per child process
runtime.GOMAXPROCS(1)
// Linux will use SO_REUSEPORT and Windows falls back to SO_REUSEADDR
// Only tcp4 or tcp6 is supported when preforking, both are not supported
if ln, err = reuseport.Listen(cfg.ListenerNetwork, addr); err != nil {
if !cfg.DisableStartupMessage {
time.Sleep(sleepDuration) // avoid colliding with startup message
}
return fmt.Errorf("prefork: %w", err)
}
// wrap a tls config around the listener if provided
if tlsConfig != nil {
ln = tls.NewListener(ln, tlsConfig)
}
// kill current child proc when master exits
masterPID := os.Getppid()
go watchMaster(masterPID)
// prepare the server for the start
app.startupProcess()
if cfg.ListenerAddrFunc != nil {
cfg.ListenerAddrFunc(ln.Addr())
}
// listen for incoming connections
return app.server.Serve(ln)
}
// 👮 master process 👮
type child struct {
err error
pid int
}
// create variables
maxProcs := runtime.GOMAXPROCS(0)
children := make(map[int]*exec.Cmd)
channel := make(chan child, maxProcs)
// kill child procs when master exits
defer func() {
for _, proc := range children {
if err = proc.Process.Kill(); err != nil {
if !errors.Is(err, os.ErrProcessDone) {
log.Errorf("prefork: failed to kill child: %v", err)
}
}
}
}()
// collect child pids
var childPIDs []int
// launch child procs
for range maxProcs {
cmd := exec.Command(os.Args[0], os.Args[1:]...) //nolint:gosec // It's fine to launch the same process again
if testPreforkMaster {
// When test prefork master,
// just start the child process with a dummy cmd,
// which will exit soon
cmd = dummyCmd()
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// add fiber prefork child flag into child proc env
cmd.Env = append(os.Environ(),
fmt.Sprintf("%s=%s", envPreforkChildKey, envPreforkChildVal),
)
if err = cmd.Start(); err != nil {
return fmt.Errorf("failed to start a child prefork process, error: %w", err)
}
// store child process
pid := cmd.Process.Pid
children[pid] = cmd
childPIDs = append(childPIDs, pid)
// execute fork hook
if app.hooks != nil {
if testOnPrefork {
app.hooks.executeOnForkHooks(dummyPid)
} else {
app.hooks.executeOnForkHooks(pid)
}
}
// notify master if child crashes
go func() {
channel <- child{pid: pid, err: cmd.Wait()}
}()
}
// Run onListen hooks
// Hooks have to be run here as different as non-prefork mode due to they should run as child or master
listenData := app.prepareListenData(addr, tlsConfig != nil, cfg, childPIDs)
app.runOnListenHooks(listenData)
app.startupMessage(listenData, cfg)
if cfg.EnablePrintRoutes {
app.printRoutesMessage()
}
// return error if child crashes
return (<-channel).err
}
// watchMaster watches the master process and exits if it dies.
// It detects master death by checking if the parent PID has changed,
// which happens when the master exits and the child is reparented to
// another process (often init/PID 1, but could be a subreaper).
func watchMaster(masterPID int) {
if runtime.GOOS == windowsOS {
// finds parent process,
// and waits for it to exit
p, err := os.FindProcess(masterPID)
if err == nil {
_, _ = p.Wait() //nolint:errcheck // It is fine to ignore the error here
}
os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork
}
// Watch for parent PID changes. When the master exits, the OS
// reparents the child to another process, causing Getppid() to change.
// Comparing against the original PID instead of hardcoding 1 ensures
// this works correctly when the master itself is PID 1 (e.g. in
// Docker containers).
const watchInterval = 500 * time.Millisecond
for range time.NewTicker(watchInterval).C {
if os.Getppid() != masterPID {
os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork
}
}
}
var (
dummyPid = 1
dummyChildCmd atomic.Value
)
// dummyCmd is for internal prefork testing
func dummyCmd() *exec.Cmd {
command := "go"
if storeCommand := dummyChildCmd.Load(); storeCommand != nil && storeCommand != "" {
command = storeCommand.(string) //nolint:forcetypeassert,errcheck // We always store a string in here
}
if runtime.GOOS == windowsOS {
return exec.Command("cmd", "/C", command, "version")
}
return exec.Command(command, "version")
}
================================================
FILE: prefork_test.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📄 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
// 💖 Maintained and modified for Fiber by @renewerner87
package fiber
import (
"crypto/tls"
"io"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_App_Prefork_Child_Process(t *testing.T) {
// Reset test var
testPreforkMaster = true
setupIsChild(t)
app := New()
cfg := listenConfigDefault()
err := app.prefork("invalid", nil, &cfg)
require.Error(t, err)
go func() {
time.Sleep(1000 * time.Millisecond)
assert.NoError(t, app.Shutdown())
}()
ipv6Cfg := ListenConfig{ListenerNetwork: NetworkTCP6}
require.NoError(t, app.prefork("[::1]:", nil, &ipv6Cfg))
// Create tls certificate
cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")
if err != nil {
require.NoError(t, err)
}
//nolint:gosec // We're in a test so using old ciphers is fine
config := &tls.Config{Certificates: []tls.Certificate{cer}}
go func() {
time.Sleep(1000 * time.Millisecond)
assert.NoError(t, app.Shutdown())
}()
cfg = listenConfigDefault()
require.NoError(t, app.prefork("127.0.0.1:", config, &cfg))
}
func Test_App_Prefork_Master_Process(t *testing.T) {
// Reset test var
testPreforkMaster = true
app := New()
go func() {
time.Sleep(1000 * time.Millisecond)
assert.NoError(t, app.Shutdown())
}()
cfg := listenConfigDefault()
require.NoError(t, app.prefork(":0", nil, &cfg))
dummyChildCmd.Store("invalid")
cfg = listenConfigDefault()
err := app.prefork("127.0.0.1:", nil, &cfg)
require.Error(t, err)
dummyChildCmd.Store("go")
}
func Test_App_Prefork_Child_Process_Never_Show_Startup_Message(t *testing.T) {
setupIsChild(t)
rescueStdout := os.Stdout
defer func() { os.Stdout = rescueStdout }()
r, w, err := os.Pipe()
require.NoError(t, err)
os.Stdout = w
cfg := listenConfigDefault()
app := New()
app.startupProcess()
listenData := app.prepareListenData(":0", false, &cfg, nil)
app.startupMessage(listenData, &cfg)
require.NoError(t, w.Close())
out, err := io.ReadAll(r)
require.NoError(t, err)
require.Empty(t, out)
}
func setupIsChild(t *testing.T) {
t.Helper()
t.Setenv(envPreforkChildKey, envPreforkChildVal)
}
================================================
FILE: readonly.go
================================================
//go:build !s390x && !ppc64 && !ppc64le
package fiber
import (
"unsafe"
)
//go:linkname runtimeRodata runtime.rodata
var runtimeRodata byte
//go:linkname runtimeErodata runtime.erodata
var runtimeErodata byte
func isReadOnly(p unsafe.Pointer) bool {
start := uintptr(unsafe.Pointer(&runtimeRodata)) //nolint:gosec // converting runtime symbols
end := uintptr(unsafe.Pointer(&runtimeErodata)) //nolint:gosec // converting runtime symbols
addr := uintptr(p)
return addr >= start && addr < end
}
================================================
FILE: readonly_strict.go
================================================
//go:build s390x || ppc64 || ppc64le
package fiber
import "unsafe"
func isReadOnly(_ unsafe.Pointer) bool {
return false
}
================================================
FILE: redirect.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📝 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"bytes"
"encoding/hex"
"sync"
"github.com/gofiber/utils/v2"
utilsbytes "github.com/gofiber/utils/v2/bytes"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
"github.com/gofiber/fiber/v3/binder"
)
// Pool for redirection
var (
redirectPool = sync.Pool{
New: func() any {
return &Redirect{
status: StatusSeeOther,
messages: make(redirectionMsgs, 0),
}
},
}
oldInputPool = sync.Pool{
New: func() any {
return make(map[string]string)
},
}
)
const maxPoolableMapSize = 64
// FlashCookieName Cookie name to send flash messages when to use redirection.
const (
FlashCookieName = "fiber_flash"
)
var flashCookieNeedle = []byte(FlashCookieName + "=")
// hasFlashCookie is on the request hot path and runs on every request/response cycle.
// Keep this cheap for users who don't use flash messages:
// 1) a fast raw-header prefilter to avoid unnecessary cookie parsing,
// 2) an exact cookie lookup to avoid prefix false positives (e.g. fiber_flashX).
func hasFlashCookie(header *fasthttp.RequestHeader) bool {
rawHeaders := header.RawHeaders()
if len(rawHeaders) == 0 {
return false
}
if !bytes.Contains(rawHeaders, flashCookieNeedle) {
return false
}
return header.Cookie(FlashCookieName) != nil
}
// redirectionMsgs is a struct that used to store flash messages and old input data in cookie using MSGP.
// msgp -file="redirect.go" -o="redirect_msgp.go" -unexported
// Cookie payloads are limited to ~4KB, so keep flash message counts bounded but usable.
//
//msgp:limit arrays:256 maps:32 marshal:true
//msgp:ignore Redirect RedirectConfig OldInputData FlashMessage
type redirectionMsg struct {
key string
value string
level uint8
isOldInput bool
}
type redirectionMsgs []redirectionMsg
// OldInputData is a struct that holds the old input data.
type OldInputData struct {
Key string
Value string
}
// FlashMessage is a struct that holds the flash message data.
type FlashMessage struct {
Key string
Value string
Level uint8
}
// Redirect is a struct that holds the redirect data.
type Redirect struct {
c *DefaultCtx // Embed ctx
messages redirectionMsgs // Flash messages and old input data
status int // Status code of redirection. Default: 303 StatusSeeOther
}
// RedirectConfig A config to use with Redirect().Route()
// You can specify queries or route parameters.
// NOTE: We don't use net/url to parse parameters because of it has poor performance. You have to pass map.
type RedirectConfig struct {
Params Map // Route parameters
Queries map[string]string // Query map
}
// AcquireRedirect return default Redirect reference from the redirect pool
func AcquireRedirect() *Redirect {
redirect, ok := redirectPool.Get().(*Redirect)
if !ok {
panic(errRedirectTypeAssertion)
}
return redirect
}
// ReleaseRedirect returns c acquired via Redirect to redirect pool.
//
// It is forbidden accessing req and/or its members after returning
// it to redirect pool.
func ReleaseRedirect(r *Redirect) {
r.release()
redirectPool.Put(r)
}
func (r *Redirect) release() {
r.status = StatusSeeOther
r.messages = r.messages[:0]
r.c = nil
}
func acquireOldInput() map[string]string {
oldInput, ok := oldInputPool.Get().(map[string]string)
if !ok {
return make(map[string]string)
}
return oldInput
}
func releaseOldInput(oldInput map[string]string) {
if len(oldInput) > maxPoolableMapSize {
return
}
clear(oldInput)
oldInputPool.Put(oldInput)
}
// Status sets the status code of redirection.
// If status is not specified, status defaults to 303 See Other.
func (r *Redirect) Status(code int) *Redirect {
r.status = code
return r
}
// With You can send flash messages by using With().
// They will be sent as a cookie.
// You can get them by using: Redirect().Messages(), Redirect().Message()
// Note: You must use escape char before using ',' and ':' chars to avoid wrong parsing.
func (r *Redirect) With(key, value string, level ...uint8) *Redirect {
// Get level
var msgLevel uint8
if len(level) > 0 {
msgLevel = level[0]
}
// Override old message if exists
for i, msg := range r.messages {
if msg.key == key && !msg.isOldInput {
r.messages[i].value = value
r.messages[i].level = msgLevel
return r
}
}
r.messages = append(r.messages, redirectionMsg{
key: key,
value: value,
level: msgLevel,
})
return r
}
// WithInput You can send input data by using WithInput().
// They will be sent as a cookie.
// This method can send form, multipart form, query data to redirected route.
// You can get them by using: Redirect().OldInputs(), Redirect().OldInput()
func (r *Redirect) WithInput() *Redirect {
// Get content-type
ctype := utils.UnsafeString(utilsbytes.UnsafeToLower(r.c.RequestCtx().Request.Header.ContentType()))
ctype = binder.FilterFlags(utils.ParseVendorSpecificContentType(ctype))
oldInput := acquireOldInput()
defer releaseOldInput(oldInput)
switch ctype {
case MIMEApplicationForm, MIMEMultipartForm:
_ = r.c.Bind().Form(oldInput) //nolint:errcheck // not needed
default:
_ = r.c.Bind().Query(oldInput) //nolint:errcheck // not needed
}
// Add old input data
for k, v := range oldInput {
r.messages = append(r.messages, redirectionMsg{
key: k,
value: v,
isOldInput: true,
})
}
return r
}
// Messages Get flash messages.
func (r *Redirect) Messages() []FlashMessage {
if len(r.c.flashMessages) == 0 {
return nil
}
flashMessages := make([]FlashMessage, 0, len(r.c.flashMessages))
writeIdx := 0
for _, msg := range r.c.flashMessages {
if msg.isOldInput {
r.c.flashMessages[writeIdx] = msg
writeIdx++
continue
}
flashMessages = append(flashMessages, FlashMessage{
Key: msg.key,
Value: msg.value,
Level: msg.level,
})
}
for i := writeIdx; i < len(r.c.flashMessages); i++ {
r.c.flashMessages[i] = redirectionMsg{}
}
r.c.flashMessages = r.c.flashMessages[:writeIdx]
return flashMessages
}
// Message Get flash message by key.
func (r *Redirect) Message(key string) FlashMessage {
if len(r.c.flashMessages) == 0 {
return FlashMessage{}
}
var flashMessage FlashMessage
found := false
writeIdx := 0
for _, msg := range r.c.flashMessages {
if msg.isOldInput || found || msg.key != key {
r.c.flashMessages[writeIdx] = msg
writeIdx++
continue
}
flashMessage = FlashMessage{
Key: msg.key,
Value: msg.value,
Level: msg.level,
}
found = true
}
for i := writeIdx; i < len(r.c.flashMessages); i++ {
r.c.flashMessages[i] = redirectionMsg{}
}
r.c.flashMessages = r.c.flashMessages[:writeIdx]
return flashMessage
}
// OldInputs Get old input data.
func (r *Redirect) OldInputs() []OldInputData {
// Count old inputs first to avoid allocation if none exist
count := 0
for _, msg := range r.c.flashMessages {
if msg.isOldInput {
count++
}
}
if count == 0 {
return nil
}
inputs := make([]OldInputData, 0, count)
for _, msg := range r.c.flashMessages {
if msg.isOldInput {
inputs = append(inputs, OldInputData{
Key: msg.key,
Value: msg.value,
})
}
}
return inputs
}
// OldInput Get old input data by key.
func (r *Redirect) OldInput(key string) OldInputData {
msgs := r.c.flashMessages
for _, msg := range msgs {
if msg.key == key && msg.isOldInput {
return OldInputData{
Key: msg.key,
Value: msg.value,
}
}
}
return OldInputData{}
}
// To redirect to the URL derived from the specified path, with specified status.
func (r *Redirect) To(location string) error {
r.c.setCanonical(HeaderLocation, location)
r.c.Status(r.status)
r.processFlashMessages()
return nil
}
// Route redirects to the route registered in the app under name, using the
// status configured on this Redirect. Route parameters and query-string pairs
// are passed through the optional RedirectConfig (only the first one is used).
// Query keys and values are written verbatim, without URL encoding, in map
// iteration order. Any error from building the route location is returned.
func (r *Redirect) Route(name string, config ...RedirectConfig) error {
	var cfg RedirectConfig
	if len(config) != 0 {
		cfg = config[0]
	}

	route := r.c.App().GetRoute(name)
	location, err := r.c.getLocationFromRoute(&route, cfg.Params)
	if err != nil {
		return err
	}
	if len(cfg.Queries) == 0 {
		return r.To(location)
	}

	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	needSep := false
	for k, v := range cfg.Queries {
		if needSep {
			buf.WriteByte('&')
		}
		needSep = true
		buf.WriteString(k)
		buf.WriteByte('=')
		buf.WriteString(v)
	}
	// The concatenation copies the pooled bytes, so returning buf to the pool afterwards is safe.
	return r.To(location + "?" + r.c.app.toString(buf.Bytes()))
}
// Back redirects to the URL in the request's Referer header. If that header is
// empty, the first fallback is used instead. If no fallback is given either,
// the response status is set to ErrRedirectBackNoFallback's code and that error
// is returned.
func (r *Redirect) Back(fallback ...string) error {
	if referer := r.c.Get(HeaderReferer); referer != "" {
		return r.To(referer)
	}
	if len(fallback) > 0 {
		return r.To(fallback[0])
	}
	err := ErrRedirectBackNoFallback
	r.c.Status(err.Code)
	return err
}
// parseAndClearFlashMessages loads the flash messages and old input data carried
// by the flash cookie into the context, then tells the client to delete that
// cookie so the data is delivered only once.
//
// The cookie value is expected to be the hex encoding of a msgp-serialized
// redirectionMsgs value, as written by processFlashMessages. A missing or empty
// cookie is a no-op. A cookie that cannot be decoded leaves the context with no
// flash data and is still expired, so the client stops resending it.
func (r *Redirect) parseAndClearFlashMessages() {
	cookieValue := r.c.Cookies(FlashCookieName)
	if cookieValue == "" {
		return
	}

	// Zero the entire backing array before decoding. The generated UnmarshalMsg
	// reuses existing capacity and does not reset fields an encoded entry omits,
	// so entries left over from earlier use of this context could otherwise
	// show up in the decoded result or survive a decode that fails part-way.
	all := r.c.flashMessages[:cap(r.c.flashMessages)]
	for i := range all {
		all[i] = redirectionMsg{}
	}
	r.c.flashMessages = all[:0]

	raw, err := hex.DecodeString(cookieValue)
	if err == nil {
		_, err = r.c.flashMessages.UnmarshalMsg(raw)
	}
	if err != nil {
		// A malformed cookie carries no usable data: drop anything decoded before the failure.
		for i := range r.c.flashMessages {
			r.c.flashMessages[i] = redirectionMsg{}
		}
		r.c.flashMessages = r.c.flashMessages[:0]
	}

	r.c.Cookie(&Cookie{
		Name:   FlashCookieName,
		Value:  "",
		Path:   "/",
		MaxAge: -1,
	})
}
// processFlashMessages serializes the flash messages and old input data queued
// on this Redirect (msgp, then hex) and stores the result in the flash cookie
// for the next request to read back. It does nothing when nothing is queued.
// A serialization error is swallowed and no flash cookie is written.
func (r *Redirect) processFlashMessages() {
	if len(r.messages) == 0 {
		return
	}

	val, err := r.messages.MarshalMsg(nil)
	if err != nil {
		return
	}

	dst := make([]byte, hex.EncodedLen(len(val)))
	hex.Encode(dst, val)

	r.c.Cookie(&Cookie{
		Name:  FlashCookieName,
		Value: r.c.app.toString(dst),
		// Scope the cookie to "/" explicitly. Without a Path, user agents default
		// it to the directory of the current request, so the redirect target may
		// not receive it. A "/" deletion cookie (as written when the flash data is
		// consumed) would also not address a cookie stored under a different path.
		Path:        "/",
		SessionOnly: true,
	})
}
================================================
FILE: redirect_msgp.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package fiber
import (
"github.com/tinylib/msgp/msgp"
)
// Size limits for msgp deserialization
const (
// zc920acdalimitArrays caps the element count of a redirectionMsgs array
// (checked when decoding, and also when encoding/marshaling).
zc920acdalimitArrays = 256
// zc920acdalimitMaps caps the number of key/value pairs accepted in a single
// redirectionMsg map when decoding.
zc920acdalimitMaps = 32
)
// DecodeMsg implements msgp.Decodable
// It reads one MessagePack map from dc and stores its "key", "value", "level"
// and "isOldInput" entries into z; unknown keys are skipped. Fields absent from
// the map keep whatever value z already held. Maps with more than
// zc920acdalimitMaps entries are rejected with msgp.ErrLimitExceeded.
func (z *redirectionMsg) DecodeMsg(dc *msgp.Reader) (err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, err = dc.ReadMapHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
// Reject oversized maps before reading any entries.
if zb0001 > zc920acdalimitMaps {
err = msgp.ErrLimitExceeded
return
}
for zb0001 > 0 {
zb0001--
field, err = dc.ReadMapKeyPtr()
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "key":
z.key, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "key")
return
}
case "value":
z.value, err = dc.ReadString()
if err != nil {
err = msgp.WrapError(err, "value")
return
}
case "level":
z.level, err = dc.ReadUint8()
if err != nil {
err = msgp.WrapError(err, "level")
return
}
case "isOldInput":
z.isOldInput, err = dc.ReadBool()
if err != nil {
err = msgp.WrapError(err, "isOldInput")
return
}
default:
// Unknown field: skip its value to stay forward compatible.
err = dc.Skip()
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
return
}
// EncodeMsg implements msgp.Encodable
// It writes z to en as a 4-entry MessagePack map. The en.Append calls emit
// pre-encoded bytes: the map header (0x84) followed by each field name as a
// fixstr (0xa0 | length, then the ASCII bytes).
func (z *redirectionMsg) EncodeMsg(en *msgp.Writer) (err error) {
// map header, size 4
// write "key"
err = en.Append(0x84, 0xa3, 0x6b, 0x65, 0x79)
if err != nil {
return
}
err = en.WriteString(z.key)
if err != nil {
err = msgp.WrapError(err, "key")
return
}
// write "value"
err = en.Append(0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
if err != nil {
return
}
err = en.WriteString(z.value)
if err != nil {
err = msgp.WrapError(err, "value")
return
}
// write "level"
err = en.Append(0xa5, 0x6c, 0x65, 0x76, 0x65, 0x6c)
if err != nil {
return
}
err = en.WriteUint8(z.level)
if err != nil {
err = msgp.WrapError(err, "level")
return
}
// write "isOldInput"
err = en.Append(0xaa, 0x69, 0x73, 0x4f, 0x6c, 0x64, 0x49, 0x6e, 0x70, 0x75, 0x74)
if err != nil {
return
}
err = en.WriteBool(z.isOldInput)
if err != nil {
err = msgp.WrapError(err, "isOldInput")
return
}
return
}
// MarshalMsg implements msgp.Marshaler
// It appends the same encoding EncodeMsg produces to b (growing it by up to
// Msgsize bytes first) and returns the extended slice. This implementation
// never sets err.
func (z *redirectionMsg) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
// map header, size 4
// string "key"
o = append(o, 0x84, 0xa3, 0x6b, 0x65, 0x79)
o = msgp.AppendString(o, z.key)
// string "value"
o = append(o, 0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
o = msgp.AppendString(o, z.value)
// string "level"
o = append(o, 0xa5, 0x6c, 0x65, 0x76, 0x65, 0x6c)
o = msgp.AppendUint8(o, z.level)
// string "isOldInput"
o = append(o, 0xaa, 0x69, 0x73, 0x4f, 0x6c, 0x64, 0x49, 0x6e, 0x70, 0x75, 0x74)
o = msgp.AppendBool(o, z.isOldInput)
return
}
// UnmarshalMsg implements msgp.Unmarshaler
// It decodes one redirectionMsg map from the front of bts into z and returns
// the bytes that follow it. Unknown keys are skipped; fields absent from the
// map keep whatever value z already held. Maps with more than
// zc920acdalimitMaps entries are rejected with msgp.ErrLimitExceeded.
func (z *redirectionMsg) UnmarshalMsg(bts []byte) (o []byte, err error) {
var field []byte
_ = field
var zb0001 uint32
zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
// Reject oversized maps before reading any entries.
if zb0001 > zc920acdalimitMaps {
err = msgp.ErrLimitExceeded
return
}
for zb0001 > 0 {
zb0001--
field, bts, err = msgp.ReadMapKeyZC(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
switch msgp.UnsafeString(field) {
case "key":
z.key, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "key")
return
}
case "value":
z.value, bts, err = msgp.ReadStringBytes(bts)
if err != nil {
err = msgp.WrapError(err, "value")
return
}
case "level":
z.level, bts, err = msgp.ReadUint8Bytes(bts)
if err != nil {
err = msgp.WrapError(err, "level")
return
}
case "isOldInput":
z.isOldInput, bts, err = msgp.ReadBoolBytes(bts)
if err != nil {
err = msgp.WrapError(err, "isOldInput")
return
}
default:
// Unknown field: skip its value.
bts, err = msgp.Skip(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
// Breakdown: 1 byte for the fixmap header, then for each field its fixstr key
// (1 prefix byte + name length: 4, 6, 6 and 11 bytes) plus the maximum size of
// its value.
func (z *redirectionMsg) Msgsize() (s int) {
s = 1 + 4 + msgp.StringPrefixSize + len(z.key) + 6 + msgp.StringPrefixSize + len(z.value) + 6 + msgp.Uint8Size + 11 + msgp.BoolSize
return
}
// DecodeMsg implements msgp.Decodable
// It reads an array header (rejecting more than zc920acdalimitArrays
// elements) and resizes *z, reusing its capacity when large enough. It then
// decodes each element in place. Reused elements are not cleared first, so a
// field missing from an element's encoding keeps its previous value.
func (z *redirectionMsgs) DecodeMsg(dc *msgp.Reader) (err error) {
var zb0002 uint32
zb0002, err = dc.ReadArrayHeader()
if err != nil {
err = msgp.WrapError(err)
return
}
if zb0002 > zc920acdalimitArrays {
err = msgp.ErrLimitExceeded
return
}
if cap((*z)) >= int(zb0002) {
(*z) = (*z)[:zb0002]
} else {
(*z) = make(redirectionMsgs, zb0002)
}
for zb0001 := range *z {
err = (*z)[zb0001].DecodeMsg(dc)
if err != nil {
err = msgp.WrapError(err, zb0001)
return
}
}
return
}
// EncodeMsg implements msgp.Encodable
// It writes z as a MessagePack array of redirectionMsg maps.
// NOTE(review): the array header is written before the zc920acdalimitArrays
// check, so an over-limit slice leaves a header in en before
// msgp.ErrLimitExceeded is returned.
func (z redirectionMsgs) EncodeMsg(en *msgp.Writer) (err error) {
err = en.WriteArrayHeader(uint32(len(z)))
if err != nil {
err = msgp.WrapError(err)
return
}
if uint32(len(z)) > zc920acdalimitArrays {
err = msgp.ErrLimitExceeded
return
}
for zb0003 := range z {
err = z[zb0003].EncodeMsg(en)
if err != nil {
err = msgp.WrapError(err, zb0003)
return
}
}
return
}
// MarshalMsg implements msgp.Marshaler
// It appends an array header plus each element's encoding to b. Slices longer
// than zc920acdalimitArrays yield (nil, msgp.ErrLimitExceeded).
func (z redirectionMsgs) MarshalMsg(b []byte) (o []byte, err error) {
o = msgp.Require(b, z.Msgsize())
o = msgp.AppendArrayHeader(o, uint32(len(z)))
if uint32(len(z)) > zc920acdalimitArrays {
return nil, msgp.ErrLimitExceeded
}
for zb0003 := range z {
o, err = z[zb0003].MarshalMsg(o)
if err != nil {
err = msgp.WrapError(err, zb0003)
return
}
}
return
}
// UnmarshalMsg implements msgp.Unmarshaler
// It decodes an array of redirectionMsg maps from the front of bts into *z and
// returns the remaining bytes. Arrays longer than zc920acdalimitArrays are
// rejected before *z is touched. Otherwise *z is resliced (reusing its capacity
// when possible) and elements are decoded in place without being cleared first.
// If an element fails to decode, *z keeps the new length, with that element and
// any later ones only partially overwritten.
func (z *redirectionMsgs) UnmarshalMsg(bts []byte) (o []byte, err error) {
var zb0002 uint32
zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts)
if err != nil {
err = msgp.WrapError(err)
return
}
if zb0002 > zc920acdalimitArrays {
err = msgp.ErrLimitExceeded
return
}
if cap((*z)) >= int(zb0002) {
(*z) = (*z)[:zb0002]
} else {
(*z) = make(redirectionMsgs, zb0002)
}
for zb0001 := range *z {
bts, err = (*z)[zb0001].UnmarshalMsg(bts)
if err != nil {
err = msgp.WrapError(err, zb0001)
return
}
}
o = bts
return
}
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
// It is the maximum array header size plus the sum of each element's Msgsize.
func (z redirectionMsgs) Msgsize() (s int) {
s = msgp.ArrayHeaderSize
for zb0003 := range z {
s += z[zb0003].Msgsize()
}
return
}
================================================
FILE: redirect_msgp_test.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package fiber
import (
"bytes"
"testing"
"github.com/tinylib/msgp/msgp"
)
// TestMarshalUnmarshalredirectionMsg round-trips a zero-value redirectionMsg
// through MarshalMsg/UnmarshalMsg and checks that neither UnmarshalMsg nor
// msgp.Skip leaves trailing bytes.
func TestMarshalUnmarshalredirectionMsg(t *testing.T) {
v := redirectionMsg{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
// BenchmarkMarshalMsgredirectionMsg measures MarshalMsg into a fresh (nil)
// buffer on every iteration; the result and error are discarded.
func BenchmarkMarshalMsgredirectionMsg(b *testing.B) {
v := redirectionMsg{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
// BenchmarkAppendMsgredirectionMsg measures MarshalMsg appending into a reused
// buffer presized with Msgsize.
func BenchmarkAppendMsgredirectionMsg(b *testing.B) {
v := redirectionMsg{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
// BenchmarkUnmarshalredirectionMsg measures UnmarshalMsg on a pre-marshaled
// zero-value redirectionMsg.
func BenchmarkUnmarshalredirectionMsg(b *testing.B) {
v := redirectionMsg{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
// TestEncodeDecoderedirectionMsg round-trips a zero-value redirectionMsg
// through the streaming msgp.Encode/msgp.Decode path. It only logs (does not
// fail) when the encoded size exceeds Msgsize, and then checks that the
// encoding can be skipped. Errors from msgp.Encode are not checked.
func TestEncodeDecoderedirectionMsg(t *testing.T) {
v := redirectionMsg{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecoderedirectionMsg Msgsize() is inaccurate")
}
vn := redirectionMsg{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
// BenchmarkEncoderedirectionMsg measures EncodeMsg into a msgp.Writer backed
// by msgp.Nowhere; per-call errors are discarded.
func BenchmarkEncoderedirectionMsg(b *testing.B) {
v := redirectionMsg{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
// BenchmarkDecoderedirectionMsg measures DecodeMsg, reading repeatedly from a
// msgp.NewEndlessReader built over the encoded bytes.
func BenchmarkDecoderedirectionMsg(b *testing.B) {
v := redirectionMsg{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}
// TestMarshalUnmarshalredirectionMsgs round-trips an empty redirectionMsgs
// through MarshalMsg/UnmarshalMsg and checks that neither UnmarshalMsg nor
// msgp.Skip leaves trailing bytes.
func TestMarshalUnmarshalredirectionMsgs(t *testing.T) {
v := redirectionMsgs{}
bts, err := v.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
left, err := v.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
}
left, err = msgp.Skip(bts)
if err != nil {
t.Fatal(err)
}
if len(left) > 0 {
t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
}
}
// BenchmarkMarshalMsgredirectionMsgs measures MarshalMsg of an empty slice
// into a fresh (nil) buffer on every iteration; results are discarded.
func BenchmarkMarshalMsgredirectionMsgs(b *testing.B) {
v := redirectionMsgs{}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.MarshalMsg(nil)
}
}
// BenchmarkAppendMsgredirectionMsgs measures MarshalMsg appending into a
// reused buffer presized with Msgsize.
func BenchmarkAppendMsgredirectionMsgs(b *testing.B) {
v := redirectionMsgs{}
bts := make([]byte, 0, v.Msgsize())
bts, _ = v.MarshalMsg(bts[0:0])
b.SetBytes(int64(len(bts)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
bts, _ = v.MarshalMsg(bts[0:0])
}
}
// BenchmarkUnmarshalredirectionMsgs measures UnmarshalMsg on a pre-marshaled
// empty redirectionMsgs.
func BenchmarkUnmarshalredirectionMsgs(b *testing.B) {
v := redirectionMsgs{}
bts, _ := v.MarshalMsg(nil)
b.ReportAllocs()
b.SetBytes(int64(len(bts)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := v.UnmarshalMsg(bts)
if err != nil {
b.Fatal(err)
}
}
}
// TestEncodeDecoderedirectionMsgs round-trips an empty redirectionMsgs through
// the streaming msgp.Encode/msgp.Decode path. It only logs (does not fail) when
// the encoded size exceeds Msgsize, and then checks that the encoding can be
// skipped. Errors from msgp.Encode are not checked.
func TestEncodeDecoderedirectionMsgs(t *testing.T) {
v := redirectionMsgs{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
m := v.Msgsize()
if buf.Len() > m {
t.Log("WARNING: TestEncodeDecoderedirectionMsgs Msgsize() is inaccurate")
}
vn := redirectionMsgs{}
err := msgp.Decode(&buf, &vn)
if err != nil {
t.Error(err)
}
buf.Reset()
msgp.Encode(&buf, &v)
err = msgp.NewReader(&buf).Skip()
if err != nil {
t.Error(err)
}
}
// BenchmarkEncoderedirectionMsgs measures EncodeMsg of an empty slice into a
// msgp.Writer backed by msgp.Nowhere; per-call errors are discarded.
func BenchmarkEncoderedirectionMsgs(b *testing.B) {
v := redirectionMsgs{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
en := msgp.NewWriter(msgp.Nowhere)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
v.EncodeMsg(en)
}
en.Flush()
}
// BenchmarkDecoderedirectionMsgs measures DecodeMsg, reading repeatedly from a
// msgp.NewEndlessReader built over the encoded bytes.
func BenchmarkDecoderedirectionMsgs(b *testing.B) {
v := redirectionMsgs{}
var buf bytes.Buffer
msgp.Encode(&buf, &v)
b.SetBytes(int64(buf.Len()))
rd := msgp.NewEndlessReader(buf.Bytes(), b)
dc := msgp.NewReader(rd)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := v.DecodeMsg(dc)
if err != nil {
b.Fatal(err)
}
}
}
================================================
FILE: redirect_test.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📝 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"bytes"
"encoding/hex"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// assertFlashCookieCleared fails the test unless setCookie, a Set-Cookie header
// value compared case-insensitively, expires the flash cookie. It must name
// FlashCookieName, carry max-age=0 or max-age=-1, and be scoped to path=/.
func assertFlashCookieCleared(t *testing.T, setCookie string) {
	t.Helper()

	header := strings.ToLower(setCookie)
	expired := strings.Contains(header, "max-age=0") || strings.Contains(header, "max-age=-1")

	require.Contains(t, header, FlashCookieName+"=")
	require.True(t, expired)
	require.Contains(t, header, "path=/")
}
// go test -run Test_Redirect_To
func Test_Redirect_To(t *testing.T) {
	t.Parallel()
	app := New()

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(ctx)

	// Each step redirects on the same context; the second overrides the default status.
	steps := []struct {
		redirect func() error
		status   int
		location string
	}{
		{func() error { return ctx.Redirect().To("http://default.com") }, StatusSeeOther, "http://default.com"},
		{func() error { return ctx.Redirect().Status(301).To("http://example.com") }, 301, "http://example.com"},
	}
	for _, step := range steps {
		require.NoError(t, step.redirect())
		require.Equal(t, step.status, ctx.Response().StatusCode())
		require.Equal(t, step.location, string(ctx.Response().Header.Peek(HeaderLocation)))
	}
}
// go test -run Test_Redirect_To_WithFlashMessages
func Test_Redirect_To_WithFlashMessages(t *testing.T) {
	t.Parallel()
	app := New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})

	redirect := ctx.Redirect().
		With("success", "2").
		With("success", "1"). // a later value for the same key is the one expected below
		With("message", "test", 2)
	require.NoError(t, redirect.To("http://example.com"))
	require.Equal(t, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(t, "http://example.com", string(ctx.Response().Header.Peek(HeaderLocation)))

	// Replay the Set-Cookie header as a request cookie so the flash cookie can be read back.
	ctx.RequestCtx().Request.Header.Set(HeaderCookie, ctx.GetRespHeader(HeaderSetCookie))
	payload, err := hex.DecodeString(ctx.Cookies(FlashCookieName))
	require.NoError(t, err)

	var got redirectionMsgs
	_, err = got.UnmarshalMsg(payload)
	require.NoError(t, err)

	require.Len(t, got, 2)
	for _, want := range []redirectionMsg{
		{key: "success", value: "1", level: 0, isOldInput: false},
		{key: "message", value: "test", level: 2, isOldInput: false},
	} {
		require.Contains(t, got, want)
	}
}
// go test -run Test_Redirect_Route_WithParams
func Test_Redirect_Route_WithParams(t *testing.T) {
	t.Parallel()
	app := New()
	handler := func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}
	app.Get("/user/:name", handler).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	cfg := RedirectConfig{Params: Map{"name": "fiber"}}
	require.NoError(t, ctx.Redirect().Route("user", cfg))

	require.Equal(t, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(t, "/user/fiber", string(ctx.Response().Header.Peek(HeaderLocation)))
}
// go test -run Test_Redirect_Route_WithParams_WithQueries
func Test_Redirect_Route_WithParams_WithQueries(t *testing.T) {
	t.Parallel()
	app := New()
	handler := func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}
	app.Get("/user/:name", handler).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	cfg := RedirectConfig{
		Params:  Map{"name": "fiber"},
		Queries: map[string]string{"data[0][name]": "john", "data[0][age]": "10", "test": "doe"},
	}
	require.NoError(t, ctx.Redirect().Route("user", cfg))
	require.Equal(t, StatusSeeOther, ctx.Response().StatusCode())

	// Map iteration order is random, so compare the parsed query rather than the raw string.
	location, err := url.Parse(string(ctx.Response().Header.Peek(HeaderLocation)))
	require.NoError(t, err, "url.Parse(location)")
	require.Equal(t, "/user/fiber", location.Path)

	want := url.Values{
		"data[0][name]": {"john"},
		"data[0][age]":  {"10"},
		"test":          {"doe"},
	}
	require.Equal(t, want, location.Query())
}
// go test -run Test_Redirect_Route_WithOptionalParams
func Test_Redirect_Route_WithOptionalParams(t *testing.T) {
	t.Parallel()
	app := New()
	handler := func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}
	app.Get("/user/:name?", handler).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	cfg := RedirectConfig{Params: Map{"name": "fiber"}}
	require.NoError(t, ctx.Redirect().Route("user", cfg))

	require.Equal(t, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(t, "/user/fiber", string(ctx.Response().Header.Peek(HeaderLocation)))
}
// go test -run Test_Redirect_Route_WithOptionalParamsWithoutValue
func Test_Redirect_Route_WithOptionalParamsWithoutValue(t *testing.T) {
	t.Parallel()
	app := New()
	handler := func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}
	app.Get("/user/:name?", handler).Name("user")

	// No RedirectConfig: the optional parameter is left empty.
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	require.NoError(t, ctx.Redirect().Route("user"))

	require.Equal(t, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(t, "/user/", string(ctx.Response().Header.Peek(HeaderLocation)))
}
// go test -run Test_Redirect_Route_WithGreedyParameters
func Test_Redirect_Route_WithGreedyParameters(t *testing.T) {
	t.Parallel()
	app := New()
	handler := func(c Ctx) error {
		return c.JSON(c.Params("+"))
	}
	app.Get("/user/+", handler).Name("user")

	// The greedy "+" parameter may contain slashes.
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	cfg := RedirectConfig{Params: Map{"+": "test/routes"}}
	require.NoError(t, ctx.Redirect().Route("user", cfg))

	require.Equal(t, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(t, "/user/test/routes", string(ctx.Response().Header.Peek(HeaderLocation)))
}
// go test -run Test_Redirect_Back
func Test_Redirect_Back(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/", func(c Ctx) error {
		return c.JSON("Home")
	}).Name("home")
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	// No Referer header is set, so Back uses the fallback.
	err := c.Redirect().Back("/")
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "/", string(c.Response().Header.Peek(HeaderLocation)))

	// Neither Referer nor fallback: Back sets the error's status and returns the sentinel.
	// ErrorIs is used instead of ErrorAs(&ErrRedirectBackNoFallback): errors.As would
	// store the matched error into the package-level variable, racing with parallel
	// tests that read it.
	err = c.Redirect().Back()
	require.Equal(t, 500, c.Response().StatusCode())
	require.ErrorIs(t, err, ErrRedirectBackNoFallback)
}
// go test -run Test_Redirect_Back_WithFlashMessages
func Test_Redirect_Back_WithFlashMessages(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	// No Referer is set, so Back falls back to "/".
	require.NoError(t, ctx.Redirect().With("success", "1").With("message", "test").Back("/"))
	require.Equal(t, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(t, "/", string(ctx.Response().Header.Peek(HeaderLocation)))

	// Replay the Set-Cookie header as a request cookie so the flash cookie can be read back.
	ctx.RequestCtx().Request.Header.Set(HeaderCookie, ctx.GetRespHeader(HeaderSetCookie))
	payload, err := hex.DecodeString(ctx.Cookies(FlashCookieName))
	require.NoError(t, err)

	var got redirectionMsgs
	_, err = got.UnmarshalMsg(payload)
	require.NoError(t, err)
	require.Len(t, got, 2)
	require.Contains(t, got, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
	require.Contains(t, got, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
}
// go test -run Test_Redirect_Back_WithReferer
func Test_Redirect_Back_WithReferer(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/", func(c Ctx) error {
		return c.JSON("Home")
	}).Name("home")
	app.Get("/back", func(c Ctx) error {
		return c.JSON("Back")
	}).Name("back")

	const referer = "/back"
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	ctx.Request().Header.Set(HeaderReferer, referer)

	// The Referer takes precedence over the "/" fallback.
	require.NoError(t, ctx.Redirect().Back("/"))
	require.Equal(t, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(t, referer, ctx.Get(HeaderReferer))
	require.Equal(t, referer, string(ctx.Response().Header.Peek(HeaderLocation)))
}
// go test -run Test_Redirect_Route_WithFlashMessages
func Test_Redirect_Route_WithFlashMessages(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
	expected := []redirectionMsg{
		{key: "success", value: "1", level: 0, isOldInput: false},
		{key: "message", value: "test", level: 0, isOldInput: false},
	}

	err := ctx.Redirect().With("success", "1").With("message", "test").Route("user")
	for _, msg := range expected {
		require.Contains(t, ctx.redirect.messages, msg)
	}
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(t, "/user", string(ctx.Response().Header.Peek(HeaderLocation)))

	// Replay the Set-Cookie header as a request cookie so the flash cookie can be read back.
	ctx.RequestCtx().Request.Header.Set(HeaderCookie, ctx.GetRespHeader(HeaderSetCookie))
	payload, err := hex.DecodeString(ctx.Cookies(FlashCookieName))
	require.NoError(t, err)

	var got redirectionMsgs
	_, err = got.UnmarshalMsg(payload)
	require.NoError(t, err)
	require.Len(t, got, len(expected))
	for _, msg := range expected {
		require.Contains(t, got, msg)
	}
}
// go test -run Test_Redirect_Route_WithOldInput
func Test_Redirect_Route_WithOldInput(t *testing.T) {
t.Parallel()
t.Run("Query", func(t *testing.T) {
t.Parallel()
app := New()
app.Get("/user", func(c Ctx) error {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().URI().SetQueryString("id=1&name=tom")
err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user")
require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
require.Contains(t, c.redirect.messages, redirectionMsg{key: "id", value: "1", isOldInput: true})
require.Contains(t, c.redirect.messages, redirectionMsg{key: "name", value: "tom", isOldInput: true})
require.NoError(t, err)
require.Equal(t, StatusSeeOther, c.Response().StatusCode())
require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation)))
c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing
var msgs redirectionMsgs
decoded, err := hex.DecodeString(c.Cookies(FlashCookieName))
require.NoError(t, err)
_, err = msgs.UnmarshalMsg(decoded)
require.NoError(t, err)
require.Len(t, msgs, 4)
require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
require.Contains(t, msgs, redirectionMsg{key: "id", value: "1", level: 0, isOldInput: true})
require.Contains(t, msgs, redirectionMsg{key: "name", value: "tom", level: 0, isOldInput: true})
})
t.Run("Form", func(t *testing.T) {
t.Parallel()
app := New()
app.Post("/user", func(c Ctx) error {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
c.Request().Header.Set(HeaderContentType, MIMEApplicationForm)
c.Request().SetBodyString("id=1&name=tom")
err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user")
require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
require.Contains(t, c.redirect.messages, redirectionMsg{key: "id", value: "1", isOldInput: true})
require.Contains(t, c.redirect.messages, redirectionMsg{key: "name", value: "tom", isOldInput: true})
require.NoError(t, err)
require.Equal(t, StatusSeeOther, c.Response().StatusCode())
require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation)))
c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing
var msgs redirectionMsgs
decoded, err := hex.DecodeString(c.Cookies(FlashCookieName))
require.NoError(t, err)
_, err = msgs.UnmarshalMsg(decoded)
require.NoError(t, err)
require.Len(t, msgs, 4)
require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
require.Contains(t, msgs, redirectionMsg{key: "id", value: "1", level: 0, isOldInput: true})
require.Contains(t, msgs, redirectionMsg{key: "name", value: "tom", level: 0, isOldInput: true})
})
t.Run("MultipartForm", func(t *testing.T) {
t.Parallel()
app := New()
app.Get("/user", func(c Ctx) error {
return c.SendString("user")
}).Name("user")
c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
require.NoError(t, writer.WriteField("id", "1"))
require.NoError(t, writer.WriteField("name", "tom"))
require.NoError(t, writer.Close())
c.Request().SetBody(body.Bytes())
c.Request().Header.Set(HeaderContentType, writer.FormDataContentType())
err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user")
require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
require.Contains(t, c.redirect.messages, redirectionMsg{key: "id", value: "1", isOldInput: true})
require.Contains(t, c.redirect.messages, redirectionMsg{key: "name", value: "tom", isOldInput: true})
require.NoError(t, err)
require.Equal(t, StatusSeeOther, c.Response().StatusCode())
require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation)))
c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing
var msgs redirectionMsgs
decoded, err := hex.DecodeString(c.Cookies(FlashCookieName))
require.NoError(t, err)
_, err = msgs.UnmarshalMsg(decoded)
require.NoError(t, err)
require.Len(t, msgs, 4)
require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
require.Contains(t, msgs, redirectionMsg{key: "id", value: "1", level: 0, isOldInput: true})
require.Contains(t, msgs, redirectionMsg{key: "name", value: "tom", level: 0, isOldInput: true})
})
}
func Test_Redirect_WithInput_ReusesClearedMap(t *testing.T) {
	t.Parallel()
	app := New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
	defer app.ReleaseCtx(ctx)

	first := redirectionMsg{key: "first", value: "1", isOldInput: true}
	second := redirectionMsg{key: "second", value: "2", isOldInput: true}

	ctx.Request().URI().SetQueryString("first=1")
	ctx.Redirect().WithInput()
	require.Contains(t, ctx.redirect.messages, first)

	// After truncating the queue, a second WithInput must see only the new input.
	ctx.redirect.messages = ctx.redirect.messages[:0]
	ctx.Request().URI().SetQueryString("second=2")
	ctx.Redirect().WithInput()

	require.Len(t, ctx.redirect.messages, 1)
	require.Contains(t, ctx.redirect.messages, second)
	require.NotContains(t, ctx.redirect.messages, first)
}
// go test -run Test_Redirect_parseAndClearFlashMessages
func Test_Redirect_parseAndClearFlashMessages(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	stored := redirectionMsgs{
		{key: "success", value: "1"},
		{key: "message", value: "test"},
		{key: "name", value: "tom", isOldInput: true},
		{key: "id", value: "1", isOldInput: true},
	}
	encoded, err := stored.MarshalMsg(nil)
	require.NoError(t, err)
	ctx.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(encoded))

	ctx.Redirect().parseAndClearFlashMessages()

	// Each flash message can be read exactly once.
	require.Equal(t, FlashMessage{Key: "success", Value: "1", Level: 0}, ctx.Redirect().Message("success"))
	require.Equal(t, FlashMessage{Key: "message", Value: "test", Level: 0}, ctx.Redirect().Message("message"))
	require.Equal(t, FlashMessage{}, ctx.Redirect().Message("success"))
	require.Equal(t, FlashMessage{}, ctx.Redirect().Message("not_message"))
	require.Empty(t, ctx.Redirect().Messages())

	// Old input data is still available after the messages were consumed.
	require.Equal(t, OldInputData{Key: "id", Value: "1"}, ctx.Redirect().OldInput("id"))
	require.Equal(t, OldInputData{Key: "name", Value: "tom"}, ctx.Redirect().OldInput("name"))
	require.Equal(t, OldInputData{}, ctx.Redirect().OldInput("not_name"))
	require.Equal(t, []OldInputData{
		{Key: "name", Value: "tom"},
		{Key: "id", Value: "1"},
	}, ctx.Redirect().OldInputs())

	assertFlashCookieCleared(t, string(ctx.Response().Header.Peek(HeaderSetCookie)))

	// A cookie that is not valid hex must not queue any outgoing messages.
	ctx.Request().Header.Set(HeaderCookie, "fiber_flash=test")
	ctx.Redirect().parseAndClearFlashMessages()
	require.Empty(t, ctx.Redirect().messages)
}
// Test_Redirect_parseAndClearFlashMessages_InvalidHex tests the case where hex decoding fails
func Test_Redirect_parseAndClearFlashMessages_InvalidHex(t *testing.T) {
	t.Parallel()
	app := New()

	// Setup request and response
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
	defer app.ReleaseCtx(c)
	// Start from a known-empty flash state so the assertions below are
	// meaningful even if the pooled context was not fully reset.
	c.flashMessages = c.flashMessages[:0]

	// Create redirect instance; released even if an assertion fails.
	r := AcquireRedirect()
	defer ReleaseRedirect(r)
	r.c = c

	// Set invalid hex value in flash cookie
	c.Request().Header.SetCookie(FlashCookieName, "not-a-valid-hex-string")

	r.parseAndClearFlashMessages()

	// parseAndClearFlashMessages stores decoded data in the context's
	// flashMessages (r.messages holds outgoing messages and is never written by
	// it), so that is the state that must remain empty.
	require.Empty(t, c.flashMessages)
	require.Empty(t, r.OldInputs())
	require.Empty(t, r.messages)
}
func Test_Redirect_Messages_ClearsFlashMessages(t *testing.T) {
	t.Parallel()
	app := New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
	defer app.ReleaseCtx(ctx)

	encoded, err := testredirectionMsgs.MarshalMsg(nil)
	require.NoError(t, err)
	ctx.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(encoded))
	ctx.Redirect().parseAndClearFlashMessages()

	wantMessages := []FlashMessage{
		{Key: "success", Value: "1", Level: 0},
		{Key: "message", Value: "test", Level: 0},
	}
	wantInputs := []OldInputData{
		{Key: "name", Value: "tom"},
		{Key: "id", Value: "1"},
	}

	// The first Messages call drains every flash message...
	require.Equal(t, wantMessages, ctx.Redirect().Messages())
	require.Empty(t, ctx.Redirect().Messages())
	require.Equal(t, FlashMessage{}, ctx.Redirect().Message("success"))
	// ...while old input data is left untouched.
	require.Equal(t, wantInputs, ctx.Redirect().OldInputs())
}
// Test_Redirect_CompleteFlowWithFlashMessages exercises the full round trip:
// one handler queues flash messages and redirects, and the next request carries
// the flash cookie, reads the messages back, and receives a cookie-clearing
// Set-Cookie header.
func Test_Redirect_CompleteFlowWithFlashMessages(t *testing.T) {
	t.Parallel()
	app := New()

	// First handler that sets flash messages and redirects
	app.Get("/source", func(c Ctx) error {
		return c.Redirect().With("string_message", "Hello, World!").
			With("number_message", "12345").
			With("bool_message", "true").
			To("/target")
	})

	// Second handler reads each flash message and echoes its value as JSON.
	app.Get("/target", func(c Ctx) error {
		return c.JSON(Map{
			"string_message": c.Redirect().Message("string_message").Value,
			"number_message": c.Redirect().Message("number_message").Value,
			"bool_message":   c.Redirect().Message("bool_message").Value,
		})
	})

	// Step 1: request the source route, which redirects and sets the flash cookie.
	sourceResp, err := app.Test(httptest.NewRequest(MethodGet, "/source", http.NoBody))
	require.NoError(t, err)
	require.NoError(t, sourceResp.Body.Close())
	require.Equal(t, StatusSeeOther, sourceResp.StatusCode)
	require.Equal(t, "/target", sourceResp.Header.Get(HeaderLocation))

	var flashCookie *http.Cookie
	for _, cookie := range sourceResp.Cookies() {
		if cookie.Name == "fiber_flash" {
			flashCookie = cookie
			break
		}
	}
	require.NotNil(t, flashCookie, "Flash cookie should be set")

	// Step 2: request the target route with the flash cookie attached.
	req := httptest.NewRequest(MethodGet, "/target", http.NoBody)
	req.Header.Set("Cookie", flashCookie.Name+"="+flashCookie.Value)
	targetResp, err := app.Test(req)
	require.NoError(t, err)
	defer func() { require.NoError(t, targetResp.Body.Close()) }()
	require.Equal(t, StatusOK, targetResp.StatusCode)
	assertFlashCookieCleared(t, targetResp.Header.Get(HeaderSetCookie))

	body, err := io.ReadAll(targetResp.Body)
	require.NoError(t, err)
	var result map[string]any
	require.NoError(t, json.Unmarshal(body, &result))

	// Flash values are strings end to end, so they come back as JSON strings
	// even when they look like numbers or booleans.
	require.Equal(t, "Hello, World!", result["string_message"])
	require.Equal(t, "12345", result["number_message"])
	require.Equal(t, "true", result["bool_message"])
}
// Test_Redirect_FlashMessagesWithSpecialChars verifies that flash message
// values containing NUL bytes, control characters, and multi-byte UTF-8
// (CJK text and emoji) survive the redirect round-trip through the flash
// cookie unchanged, and that the cookie is cleared afterwards.
func Test_Redirect_FlashMessagesWithSpecialChars(t *testing.T) {
	t.Parallel()
	app := New()

	// Handler that sets flash messages with special characters and redirects
	app.Get("/special-source", func(c Ctx) error {
		// Each value exercises a different class of bytes that the cookie
		// encoding must preserve.
		return c.Redirect().With("null_bytes", "Contains\x00null\x00bytes").
			With("control_chars", "Contains\r\ncontrol\tcharacters").
			With("unicode", "Unicode: 你好世界").
			With("emoji", "Emoji: 🔥🚀😊").
			To("/special-target")
	})

	// Target handler that receives the flash messages
	app.Get("/special-target", func(c Ctx) error {
		return c.JSON(Map{
			"null_bytes":    c.Redirect().Message("null_bytes").Value,
			"control_chars": c.Redirect().Message("control_chars").Value,
			"unicode":       c.Redirect().Message("unicode").Value,
			"emoji":         c.Redirect().Message("emoji").Value,
		})
	})

	// Step 1: Make the initial request
	req := httptest.NewRequest(MethodGet, "/special-source", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, resp.StatusCode)
	require.Equal(t, "/special-target", resp.Header.Get(HeaderLocation))

	// Get the flash cookie
	var flashCookie *http.Cookie
	for _, cookie := range resp.Cookies() {
		if cookie.Name == FlashCookieName {
			flashCookie = cookie
			break
		}
	}
	require.NotNil(t, flashCookie, "Flash cookie should be set")

	// Step 2: Make the second request with the cookie
	req = httptest.NewRequest(MethodGet, "/special-target", http.NoBody)
	req.Header.Set(HeaderCookie, flashCookie.Name+"="+flashCookie.Value)
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, StatusOK, resp.StatusCode)
	assertFlashCookieCleared(t, resp.Header.Get(HeaderSetCookie))

	// Parse and verify the response
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	var result map[string]any
	err = json.Unmarshal(body, &result)
	require.NoError(t, err)

	// Verify special character handling
	require.Equal(t, "Contains\x00null\x00bytes", result["null_bytes"])
	require.Equal(t, "Contains\r\ncontrol\tcharacters", result["control_chars"])
	require.Equal(t, "Unicode: 你好世界", result["unicode"])
	require.Equal(t, "Emoji: 🔥🚀😊", result["emoji"])
}
// Benchmark_Redirect_Route measures redirecting to a named route whose path
// contains a single parameter.
//
// go test -v -run=^$ -bench=Benchmark_Redirect_Route -benchmem -count=4
func Benchmark_Redirect_Route(b *testing.B) {
	app := New()
	app.Get("/user/:name", func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	var routeErr error
	b.ReportAllocs()
	for b.Loop() {
		routeErr = ctx.Redirect().Route("user", RedirectConfig{
			Params: Map{"name": "fiber"},
		})
	}

	require.NoError(b, routeErr)
	require.Equal(b, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(b, "/user/fiber", string(ctx.Response().Header.Peek(HeaderLocation)))
}
// Benchmark_Redirect_Route_WithQueries measures redirecting to a named route
// with both a path parameter and query parameters.
//
// go test -v -run=^$ -bench=Benchmark_Redirect_Route_WithQueries -benchmem -count=4
func Benchmark_Redirect_Route_WithQueries(b *testing.B) {
	app := New()
	app.Get("/user/:name", func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	var routeErr error
	b.ReportAllocs()
	for b.Loop() {
		routeErr = ctx.Redirect().Route("user", RedirectConfig{
			Params:  Map{"name": "fiber"},
			Queries: map[string]string{"a": "a", "b": "b"},
		})
	}

	require.NoError(b, routeErr)
	require.Equal(b, StatusSeeOther, ctx.Response().StatusCode())

	// Queries come from a map, so their order in Location is not stable:
	// compare the parsed URL instead of the raw header string.
	location, err := url.Parse(string(ctx.Response().Header.Peek(HeaderLocation)))
	require.NoError(b, err, "url.Parse(location)")
	require.Equal(b, "/user/fiber", location.Path)
	require.Equal(b, url.Values{"a": {"a"}, "b": {"b"}}, location.Query())
}
// Benchmark_Redirect_Route_WithFlashMessages measures a named-route redirect
// that also attaches two flash messages.
//
// go test -v -run=^$ -bench=Benchmark_Redirect_Route_WithFlashMessages -benchmem -count=4
func Benchmark_Redirect_Route_WithFlashMessages(b *testing.B) {
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	var routeErr error
	b.ReportAllocs()
	for b.Loop() {
		routeErr = ctx.Redirect().With("success", "1").With("message", "test").Route("user")
	}

	require.NoError(b, routeErr)
	require.Equal(b, StatusSeeOther, ctx.Response().StatusCode())
	require.Equal(b, "/user", string(ctx.Response().Header.Peek(HeaderLocation)))

	// Echo the response Set-Cookie back as a request cookie so it can be read
	// through Cookies (necessary for testing).
	ctx.RequestCtx().Request.Header.Set(HeaderCookie, ctx.GetRespHeader(HeaderSetCookie))

	raw, err := hex.DecodeString(ctx.Cookies(FlashCookieName))
	require.NoError(b, err)

	var msgs redirectionMsgs
	_, err = msgs.UnmarshalMsg(raw)
	require.NoError(b, err)

	require.Contains(b, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
	require.Contains(b, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
}
// testredirectionMsgs is the shared fixture for the flash-message tests and
// benchmarks. It holds two flash messages ("success", "message") followed by
// two old-input entries ("name", "id").
var testredirectionMsgs = redirectionMsgs{
	{key: "success", value: "1"},
	{key: "message", value: "test"},
	{key: "name", value: "tom", isOldInput: true},
	{key: "id", value: "1", isOldInput: true},
}
// Benchmark_Redirect_parseAndClearFlashMessages measures decoding the flash
// cookie from the request into the context's messages and old inputs.
//
// go test -v -run=^$ -bench=Benchmark_Redirect_parseAndClearFlashMessages -benchmem -count=4
func Benchmark_Redirect_parseAndClearFlashMessages(b *testing.B) {
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	encoded, err := testredirectionMsgs.MarshalMsg(nil)
	require.NoError(b, err)
	ctx.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(encoded))

	b.ReportAllocs()
	for b.Loop() {
		ctx.Redirect().parseAndClearFlashMessages()
	}

	redirect := ctx.Redirect()
	require.Equal(b, FlashMessage{Key: "success", Value: "1"}, redirect.Message("success"))
	require.Equal(b, FlashMessage{Key: "message", Value: "test"}, redirect.Message("message"))
	require.Equal(b, OldInputData{Key: "id", Value: "1"}, redirect.OldInput("id"))
	require.Equal(b, OldInputData{Key: "name", Value: "tom"}, redirect.OldInput("name"))
}
// Benchmark_Redirect_processFlashMessages measures encoding pending flash
// messages into the response cookie.
//
// go test -v -run=^$ -bench=Benchmark_Redirect_processFlashMessages -benchmem -count=4
func Benchmark_Redirect_processFlashMessages(b *testing.B) {
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
	ctx.Redirect().With("success", "1").With("message", "test")

	b.ReportAllocs()
	for b.Loop() {
		ctx.Redirect().processFlashMessages()
	}

	// Echo the response Set-Cookie back as a request cookie so it can be read
	// through Cookies (necessary for testing).
	ctx.RequestCtx().Request.Header.Set(HeaderCookie, ctx.GetRespHeader(HeaderSetCookie))

	raw, err := hex.DecodeString(ctx.Cookies(FlashCookieName))
	require.NoError(b, err)

	var msgs redirectionMsgs
	_, err = msgs.UnmarshalMsg(raw)
	require.NoError(b, err)

	require.Len(b, msgs, 2)
	require.Contains(b, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
	require.Contains(b, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
}
// Benchmark_Redirect_Messages measures Messages after refilling the context's
// flash-message slice from the shared fixture on every iteration.
//
// go test -v -run=^$ -bench=Benchmark_Redirect_Messages -benchmem -count=4
func Benchmark_Redirect_Messages(b *testing.B) {
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	encoded, err := testredirectionMsgs.MarshalMsg(nil)
	require.NoError(b, err)
	ctx.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(encoded))
	ctx.Redirect().parseAndClearFlashMessages()

	fixture := testredirectionMsgs
	var got []FlashMessage

	b.ReportAllocs()
	for b.Loop() {
		// Messages consumes the stored messages, so restore them each round.
		ctx.flashMessages = append(ctx.flashMessages[:0], fixture...)
		got = ctx.Redirect().Messages()
	}

	require.Contains(b, got, FlashMessage{Key: "success", Value: "1", Level: 0})
	require.Contains(b, got, FlashMessage{Key: "message", Value: "test", Level: 0})
}
// Benchmark_Redirect_OldInputs measures retrieving all old-input entries after
// the flash cookie has been parsed once.
//
// go test -v -run=^$ -bench=Benchmark_Redirect_OldInputs -benchmem -count=4
func Benchmark_Redirect_OldInputs(b *testing.B) {
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	encoded, err := testredirectionMsgs.MarshalMsg(nil)
	require.NoError(b, err)
	ctx.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(encoded))
	ctx.Redirect().parseAndClearFlashMessages()

	var inputs []OldInputData
	b.ReportAllocs()
	for b.Loop() {
		inputs = ctx.Redirect().OldInputs()
	}

	require.Contains(b, inputs, OldInputData{Key: "name", Value: "tom"})
	require.Contains(b, inputs, OldInputData{Key: "id", Value: "1"})
}
// Benchmark_Redirect_Message measures looking up a single flash message by key
// after refilling the context's flash-message slice on every iteration.
//
// go test -v -run=^$ -bench=Benchmark_Redirect_Message -benchmem -count=4
func Benchmark_Redirect_Message(b *testing.B) {
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	encoded, err := testredirectionMsgs.MarshalMsg(nil)
	require.NoError(b, err)
	ctx.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(encoded))
	ctx.Redirect().parseAndClearFlashMessages()

	fixture := testredirectionMsgs
	var got FlashMessage

	b.ReportAllocs()
	for b.Loop() {
		// Restore the stored messages before each lookup.
		ctx.flashMessages = append(ctx.flashMessages[:0], fixture...)
		got = ctx.Redirect().Message("message")
	}

	require.Equal(b, FlashMessage{Key: "message", Value: "test", Level: 0}, got)
}
// Benchmark_Redirect_OldInput measures looking up a single old-input entry by
// key after the flash cookie has been parsed once.
//
// go test -v -run=^$ -bench=Benchmark_Redirect_OldInput -benchmem -count=4
func Benchmark_Redirect_OldInput(b *testing.B) {
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	encoded, err := testredirectionMsgs.MarshalMsg(nil)
	require.NoError(b, err)
	ctx.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(encoded))
	ctx.Redirect().parseAndClearFlashMessages()

	var got OldInputData
	b.ReportAllocs()
	for b.Loop() {
		got = ctx.Redirect().OldInput("name")
	}

	require.Equal(b, OldInputData{Key: "name", Value: "tom"}, got)
}
================================================
FILE: register.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
// Register is the chainable route-registration interface returned by
// RouteChain. Each method registers handlers on the chained path and returns
// a Register, so registrations can be chained fluently.
type Register interface {
	// All registers handlers that match every HTTP method.
	All(handler any, handlers ...any) Register
	// Get registers handlers for GET requests.
	Get(handler any, handlers ...any) Register
	// Head registers handlers for HEAD requests.
	Head(handler any, handlers ...any) Register
	// Post registers handlers for POST requests.
	Post(handler any, handlers ...any) Register
	// Put registers handlers for PUT requests.
	Put(handler any, handlers ...any) Register
	// Delete registers handlers for DELETE requests.
	Delete(handler any, handlers ...any) Register
	// Connect registers handlers for CONNECT requests.
	Connect(handler any, handlers ...any) Register
	// Options registers handlers for OPTIONS requests.
	Options(handler any, handlers ...any) Register
	// Trace registers handlers for TRACE requests.
	Trace(handler any, handlers ...any) Register
	// Patch registers handlers for PATCH requests.
	Patch(handler any, handlers ...any) Register

	// Add registers handlers for every method in methods.
	Add(methods []string, handler any, handlers ...any) Register

	// RouteChain returns a Register whose path extends the current one.
	RouteChain(path string) Register
}
// Compile-time check that *Registering implements Register.
var _ Register = &Registering{}
// Registering provides route registration helpers for a specific path on the
// application instance. It is the concrete implementation of Register; every
// method registers on app using path and group, then returns the receiver (or,
// for RouteChain, a new Registering) for chaining.
type Registering struct {
	// app is the application that routes are registered on (via app.register).
	app *App
	// group is forwarded unchanged to app.register and copied into chained
	// instances by RouteChain. NOTE(review): may be nil for app-level chains — confirm at construction sites.
	group *Group
	// path is the route path used for every registration; RouteChain derives
	// child paths from it with getGroupPath.
	path string
}
// All registers middleware handlers on the chained path that match requests
// of every HTTP method (GET, POST, PUT, HEAD, ...).
//
//	app.RouteChain("/").All(func(c fiber.Ctx) error {
//	     return c.Next()
//	})
//	app.RouteChain("/api").All(func(c fiber.Ctx) error {
//	     return c.Next()
//	})
//	app.RouteChain("/api").All(handler, func(c fiber.Ctx) error {
//	     return c.Next()
//	})
//
// Registration is identical to Add with the internal "use" method marker.
func (r *Registering) All(handler any, handlers ...any) Register {
	return r.Add([]string{methodUse}, handler, handlers...)
}
// Get registers handlers for GET requests on the chained path. GET requests a
// representation of the target resource and should only retrieve data.
func (r *Registering) Get(handler any, handlers ...any) Register {
	methods := []string{MethodGet}
	return r.Add(methods, handler, handlers...)
}

// Head registers handlers for HEAD requests on the chained path. HEAD asks for
// a response identical to GET, but without the response body.
func (r *Registering) Head(handler any, handlers ...any) Register {
	methods := []string{MethodHead}
	return r.Add(methods, handler, handlers...)
}

// Post registers handlers for POST requests on the chained path. POST submits
// an entity to the target resource, often changing server state.
func (r *Registering) Post(handler any, handlers ...any) Register {
	methods := []string{MethodPost}
	return r.Add(methods, handler, handlers...)
}

// Put registers handlers for PUT requests on the chained path. PUT replaces
// all current representations of the target resource with the payload.
func (r *Registering) Put(handler any, handlers ...any) Register {
	methods := []string{MethodPut}
	return r.Add(methods, handler, handlers...)
}

// Delete registers handlers for DELETE requests on the chained path. DELETE
// removes the target resource.
func (r *Registering) Delete(handler any, handlers ...any) Register {
	methods := []string{MethodDelete}
	return r.Add(methods, handler, handlers...)
}

// Connect registers handlers for CONNECT requests on the chained path. CONNECT
// establishes a tunnel to the server identified by the target resource.
func (r *Registering) Connect(handler any, handlers ...any) Register {
	methods := []string{MethodConnect}
	return r.Add(methods, handler, handlers...)
}

// Options registers handlers for OPTIONS requests on the chained path. OPTIONS
// describes the communication options for the target resource.
func (r *Registering) Options(handler any, handlers ...any) Register {
	methods := []string{MethodOptions}
	return r.Add(methods, handler, handlers...)
}

// Trace registers handlers for TRACE requests on the chained path. TRACE
// performs a message loop-back test along the path to the target resource.
func (r *Registering) Trace(handler any, handlers ...any) Register {
	methods := []string{MethodTrace}
	return r.Add(methods, handler, handlers...)
}

// Patch registers handlers for PATCH requests on the chained path. PATCH
// applies partial modifications to a resource.
func (r *Registering) Patch(handler any, handlers ...any) Register {
	methods := []string{MethodPatch}
	return r.Add(methods, handler, handlers...)
}
// Add registers handler followed by handlers on the chained path for each of
// the given HTTP methods. Handlers execute in argument order: handler first,
// then each element of handlers. The receiver is returned for chaining.
func (r *Registering) Add(methods []string, handler any, handlers ...any) Register {
	all := make([]any, 0, 1+len(handlers))
	all = append(all, handler)
	all = append(all, handlers...)

	r.app.register(methods, r.path, r.group, collectHandlers("register", all...)...)
	return r
}
// RouteChain returns a new Register whose path is path joined onto the current
// path (using getGroupPath). The returned instance shares this instance's app
// and group; the receiver itself is not modified.
func (r *Registering) RouteChain(path string) Register {
	return &Registering{
		app:   r.app,
		group: r.group,
		path:  getGroupPath(r.path, path),
	}
}
================================================
FILE: req.go
================================================
package fiber
import (
"bytes"
"errors"
"fmt"
"math"
"mime/multipart"
"net"
"strconv"
"strings"
"github.com/gofiber/utils/v2"
utilsbytes "github.com/gofiber/utils/v2/bytes"
utilsstrings "github.com/gofiber/utils/v2/strings"
"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
"golang.org/x/net/idna"
)
// Byte-slice forms of header names and values, allocated once at package
// initialization so header comparisons do not convert strings to []byte on
// every request.
var (
	// Common prefix of the X-Forwarded-* header family.
	xForwardedPrefix = []byte("X-Forwarded-")

	// Individual forwarded-scheme related header names.
	xForwardedProtoBytes    = []byte(HeaderXForwardedProto)
	xForwardedProtocolBytes = []byte(HeaderXForwardedProtocol)
	xForwardedSslBytes      = []byte(HeaderXForwardedSsl)
	xURLSchemeBytes         = []byte(HeaderXUrlScheme)

	// Header value "on". NOTE(review): presumably compared against the
	// X-Forwarded-Ssl value — confirm at the use site.
	onBytes = []byte("on")
)
// Range represents the parsed HTTP Range header extracted by DefaultReq.Range.
type Range struct {
	// Type is the range unit from the header. NOTE(review): presumably "bytes"
	// in practice — confirm against DefaultReq.Range.
	Type string
	// Ranges holds each requested range in header order (assumed; verify in DefaultReq.Range).
	Ranges []RangeSet
}
// RangeSet is one content range taken from a request's Range header, given by
// its Start and End offsets. NOTE(review): whether End is inclusive is decided
// by DefaultReq.Range — confirm there.
type RangeSet struct {
	Start, End int64
}
// DefaultReq is the default implementation of Req used by DefaultCtx.
// All of its methods read request data through the owning DefaultCtx (c),
// which provides the underlying fasthttp context and the App configuration.
//
//go:generate ifacemaker --file req.go --struct DefaultReq --iface Req --pkg fiber --output req_interface_gen.go --not-exported true --iface-comment "Req is an interface for request-related Ctx methods."
type DefaultReq struct {
	// c is the context this request view belongs to.
	c *DefaultCtx
}
// Accepts checks whether any of the given extensions or content types is
// acceptable according to the request's Accept header. Multiple Accept header
// lines are joined before matching; the selection is done by getOffer.
func (r *DefaultReq) Accepts(offers ...string) string {
	values := r.c.fasthttp.Request.Header.PeekAll(HeaderAccept)
	return getOffer(joinHeaderValues(values), acceptsOfferType, offers...)
}
// AcceptsCharsets checks whether any of the given charsets is acceptable
// according to the request's Accept-Charset header (all lines joined).
func (r *DefaultReq) AcceptsCharsets(offers ...string) string {
	values := r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptCharset)
	return getOffer(joinHeaderValues(values), acceptsOffer, offers...)
}
// AcceptsEncodings checks whether any of the given encodings is acceptable
// according to the request's Accept-Encoding header (all lines joined).
func (r *DefaultReq) AcceptsEncodings(offers ...string) string {
	values := r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptEncoding)
	return getOffer(joinHeaderValues(values), acceptsOffer, offers...)
}
// AcceptsLanguages checks whether any of the given languages is acceptable
// according to the request's Accept-Language header (all lines joined),
// using RFC 4647 Basic Filtering.
func (r *DefaultReq) AcceptsLanguages(offers ...string) string {
	values := r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptLanguage)
	return getOffer(joinHeaderValues(values), acceptsLanguageOfferBasic, offers...)
}
// AcceptsLanguagesExtended checks whether any of the given languages is
// acceptable according to the request's Accept-Language header (all lines
// joined), using RFC 4647 Extended Filtering.
func (r *DefaultReq) AcceptsLanguagesExtended(offers ...string) string {
	values := r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptLanguage)
	return getOffer(joinHeaderValues(values), acceptsLanguageOfferExtended, offers...)
}
// App returns the Fiber application instance that this request's context
// belongs to.
func (r *DefaultReq) App() *App {
	owner := r.c.app
	return owner
}
// BaseURL returns protocol + host + base path by delegating to the owning
// context.
func (r *DefaultReq) BaseURL() string {
	ctx := r.c
	return ctx.BaseURL()
}
// BodyRaw returns the request body exactly as received, without decoding any
// Content-Encoding. The slice is only valid within the handler: do not keep
// references to it; copy it or enable the Immutable setting instead.
func (r *DefaultReq) BodyRaw() []byte {
	raw := r.getBody()
	return raw
}
// tryDecodeBodyInOrder removes the content codings listed in encodings from
// the request body. It walks the list from the last entry to the first,
// because RFC 9110 §8.4 lists codings in the order they were applied.
//
// Intermediate results are written back into the request with SetBodyRaw, so
// the next fasthttp decoder call (BodyGunzip, BodyUnbrotli, ...) operates on
// them. Before the first such write, a copy of the original request body is
// stored in *originalBody so the caller can restore it. With a single coding,
// nothing is written back and *originalBody is left untouched. The result of
// the final step is returned but not written back.
//
// Return values:
//   - body: the output of the last decode step (for "identity", the current
//     request body without copying).
//   - decodesRealized: the number of decode steps attempted.
//   - err: ErrNotImplemented for "compress"/"x-compress", ErrUnsupportedMediaType
//     for any other unknown coding, or the decoder's own error.
//
// A coding rejected as not implemented or unsupported is not counted in
// decodesRealized. A coding whose decoder returns an error is counted.
//
//nolint:nonamedreturns // gocritic unnamedResult prefers naming decoded body, decode count, and error
func (r *DefaultReq) tryDecodeBodyInOrder(
	originalBody *[]byte,
	encodings []string,
) (body []byte, decodesRealized uint8, err error) {
	request := &r.c.fasthttp.Request
	for idx := range encodings {
		// Walk encodings from the last entry to the first.
		i := len(encodings) - 1 - idx
		encoding := encodings[i]
		decodesRealized++
		var decodeErr error
		switch encoding {
		case StrGzip, "x-gzip":
			body, decodeErr = request.BodyGunzip()
		case StrBr, StrBrotli:
			body, decodeErr = request.BodyUnbrotli()
		case StrDeflate:
			body, decodeErr = request.BodyInflate()
		case StrZstd:
			body, decodeErr = request.BodyUnzstd()
		case StrIdentity:
			body = request.Body()
		case StrCompress, "x-compress":
			// The rejected coding is not counted as a realized decode.
			return nil, decodesRealized - 1, ErrNotImplemented
		default:
			return nil, decodesRealized - 1, ErrUnsupportedMediaType
		}
		if decodeErr != nil {
			return nil, decodesRealized, decodeErr
		}
		// More codings remain: persist this intermediate result so the next
		// decoder call sees it.
		// NOTE(review): decodesRealized is always >= 1 at this point, so the
		// second condition is redundant.
		if i > 0 && decodesRealized > 0 {
			if i == len(encodings)-1 {
				// First step: snapshot the untouched original body before
				// overwriting it, so the caller can restore it afterwards.
				tempBody := request.Body()
				*originalBody = make([]byte, len(tempBody))
				copy(*originalBody, tempBody)
			}
			request.SetBodyRaw(body)
		}
	}
	return body, decodesRealized, nil
}
// Body contains the raw body submitted in a POST request.
// This method will decompress the body if the 'Content-Encoding' header is provided.
// It returns the original (or decompressed) body data which is valid only within the handler.
// Don't store direct references to the returned data.
// If you need to keep the body's data later, make a copy or use the Immutable option.
//
// Decoding is delegated to tryDecodeBodyInOrder. When that step had to rewrite
// the request body (more than one coding), the original encoded bytes are
// written back afterwards. The decoded data is therefore only available
// through the return value.
//
// On failure, the returned slice contains the error message instead of a body.
// The response status is set to StatusUnsupportedMediaType for unknown codings
// and StatusNotImplemented for "compress"/"x-compress". Other decode errors
// leave the status unchanged.
func (r *DefaultReq) Body() []byte {
	var (
		err                error
		body, originalBody []byte
		headerEncoding     string
		// NOTE(review): presumably a pre-sized scratch slice that
		// getSplicedStrList fills to avoid allocating for up to three
		// codings — confirm in getSplicedStrList.
		encodingOrder = []string{"", "", ""}
	)
	request := &r.c.fasthttp.Request
	// Get Content-Encoding header
	// NOTE(review): UnsafeToLower/UnsafeString presumably lowercase the header
	// bytes in place and alias them without copying — confirm in gofiber/utils.
	headerEncoding = utils.UnsafeString(utilsbytes.UnsafeToLower(request.Header.ContentEncoding()))
	// If no encoding is provided, return the original body
	if headerEncoding == "" {
		return r.getBody()
	}
	// Split and get the encodings list, in order to attend the
	// rule defined at: https://www.rfc-editor.org/rfc/rfc9110#section-8.4-5
	encodingOrder = getSplicedStrList(headerEncoding, encodingOrder)
	// Normalize each coding name to lowercase before matching.
	for i := range encodingOrder {
		encodingOrder[i] = utilsstrings.UnsafeToLower(encodingOrder[i])
	}
	if len(encodingOrder) == 0 {
		return r.getBody()
	}
	var decodesRealized uint8
	body, decodesRealized, err = r.tryDecodeBodyInOrder(&originalBody, encodingOrder)
	// Ensure that the body will be the original
	if originalBody != nil && decodesRealized > 0 {
		request.SetBodyRaw(originalBody)
	}
	if err != nil {
		switch {
		case errors.Is(err, ErrUnsupportedMediaType):
			_ = r.c.DefaultRes.SendStatus(StatusUnsupportedMediaType) //nolint:errcheck,staticcheck // It is fine to ignore the error and the static check
		case errors.Is(err, ErrNotImplemented):
			_ = r.c.DefaultRes.SendStatus(StatusNotImplemented) //nolint:errcheck,staticcheck // It is fine to ignore the error and the static check
		default:
			// do nothing
		}
		// The error text is returned in place of the body.
		return []byte(err.Error())
	}
	// NOTE(review): GetBytes presumably copies when Immutable is enabled — confirm in app.go.
	return r.c.app.GetBytes(body)
}
// RequestCtx returns the underlying *fasthttp.RequestCtx, which carries a
// deadline, a cancellation signal, and other values across API boundaries.
func (r *DefaultReq) RequestCtx() *fasthttp.RequestCtx {
	fctx := r.c.fasthttp
	return fctx
}
// FullURL returns the full request URL (protocol + host + original URL).
//
// The URL is assembled in a pooled buffer that is returned to the pool when
// FullURL exits. The result is therefore always an independent copy of the
// buffer's contents and never a view into it.
func (c *DefaultCtx) FullURL() string {
	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	buf.WriteString(c.Scheme())
	buf.WriteString("://")
	buf.WriteString(c.Host())
	buf.WriteString(c.OriginalURL())

	// buf.String() copies. c.app.toString may return a zero-copy view (it is
	// used that way for fasthttp-owned bytes elsewhere in this file). Such a
	// view of buf.B would point at memory the deferred Put hands back for reuse.
	return buf.String()
}
// UserAgent returns the value of the request's User-Agent header.
func (c *DefaultCtx) UserAgent() string {
	ua := c.fasthttp.Request.Header.UserAgent()
	return c.app.toString(ua)
}
// Referer returns the value of the request's Referer header.
func (c *DefaultCtx) Referer() string {
	ref := c.fasthttp.Request.Header.Referer()
	return c.app.toString(ref)
}
// AcceptLanguage returns the value of the request's Accept-Language header
// as returned by Peek (a single header value, not all lines joined).
func (c *DefaultCtx) AcceptLanguage() string {
	raw := c.fasthttp.Request.Header.Peek(HeaderAcceptLanguage)
	return c.app.toString(raw)
}
// AcceptEncoding returns the value of the request's Accept-Encoding header
// as returned by Peek (a single header value, not all lines joined).
func (c *DefaultCtx) AcceptEncoding() string {
	raw := c.fasthttp.Request.Header.Peek(HeaderAcceptEncoding)
	return c.app.toString(raw)
}
// HasHeader reports whether the request carries a non-empty value for the
// header key. A header that is present with an empty value reports false.
func (c *DefaultCtx) HasHeader(key string) bool {
	value := c.fasthttp.Request.Header.Peek(key)
	return len(value) != 0
}
// MediaType returns the MIME type from the Content-Type header, i.e. the part
// before the first ';', with surrounding whitespace trimmed. It returns ""
// when the header is missing or blank.
func (c *DefaultCtx) MediaType() string {
	header := utils.TrimSpace(c.fasthttp.Request.Header.ContentType())
	if len(header) == 0 {
		return ""
	}
	mime, _, _ := bytes.Cut(header, []byte{';'})
	return c.app.toString(utils.TrimSpace(mime))
}
// Charset returns the value of the charset parameter of the Content-Type
// header. The parameter name is matched case-insensitively. One pair of
// surrounding double quotes is stripped from the value.
//
// It returns "" when the header is missing, has no parameters, or has no
// charset parameter.
func (c *DefaultCtx) Charset() string {
	header := c.fasthttp.Request.Header.ContentType()
	if len(header) == 0 {
		return ""
	}
	_, rest, hasParams := bytes.Cut(header, []byte{';'})
	if !hasParams {
		return ""
	}

	charsetName := []byte("charset")
	for len(rest) > 0 {
		rest = utils.TrimSpace(rest)
		if len(rest) == 0 {
			return ""
		}

		// Take the next ';'-separated parameter; rest becomes nil on the last one.
		var param []byte
		param, rest, _ = bytes.Cut(rest, []byte{';'})
		param = utils.TrimSpace(param)
		if len(param) == 0 {
			continue
		}

		name, value, hasValue := bytes.Cut(param, []byte{'='})
		if !hasValue || !bytes.EqualFold(utils.TrimSpace(name), charsetName) {
			continue
		}

		value = utils.TrimSpace(value)
		if len(value) >= 2 && value[0] == '"' && value[len(value)-1] == '"' {
			value = value[1 : len(value)-1]
		}
		return c.app.toString(value)
	}
	return ""
}
// IsJSON reports whether the request's media type (see MediaType) equals
// MIMEApplicationJSON, ignoring case.
func (c *DefaultCtx) IsJSON() bool {
	mediaType := c.MediaType()
	return utils.EqualFold(mediaType, MIMEApplicationJSON)
}

// IsForm reports whether the request's media type (see MediaType) equals
// MIMEApplicationForm, ignoring case.
func (c *DefaultCtx) IsForm() bool {
	mediaType := c.MediaType()
	return utils.EqualFold(mediaType, MIMEApplicationForm)
}

// IsMultipart reports whether the request's media type (see MediaType) equals
// MIMEMultipartForm, ignoring case.
func (c *DefaultCtx) IsMultipart() bool {
	mediaType := c.MediaType()
	return utils.EqualFold(mediaType, MIMEMultipartForm)
}
// AcceptsJSON reports whether the Accept header allows MIMEApplicationJSON,
// i.e. whether Accepts selects a non-empty offer.
func (c *DefaultCtx) AcceptsJSON() bool {
	return len(c.Accepts(MIMEApplicationJSON)) > 0
}

// AcceptsHTML reports whether the Accept header allows MIMETextHTML.
func (c *DefaultCtx) AcceptsHTML() bool {
	return len(c.Accepts(MIMETextHTML)) > 0
}

// AcceptsXML reports whether the Accept header allows either
// MIMEApplicationXML or MIMETextXML.
func (c *DefaultCtx) AcceptsXML() bool {
	return len(c.Accepts(MIMEApplicationXML, MIMETextXML)) > 0
}

// AcceptsEventStream reports whether the Accept header allows
// text/event-stream.
func (c *DefaultCtx) AcceptsEventStream() bool {
	return len(c.Accepts("text/event-stream")) > 0
}
// Cookies returns the value of the request cookie named key. When the cookie
// is absent or empty, it returns the first defaultValue if one is given, and
// "" otherwise. The result is only valid within the handler: do not keep
// references to it; copy it or enable the Immutable setting instead.
func (r *DefaultReq) Cookies(key string, defaultValue ...string) string {
	raw := r.c.fasthttp.Request.Header.Cookie(key)
	return defaultString(r.c.app.toString(raw), defaultValue)
}
// Request returns the underlying *fasthttp.Request, giving access to all
// fasthttp request methods.
// https://godoc.org/github.com/valyala/fasthttp#Request
func (r *DefaultReq) Request() *fasthttp.Request {
	fctx := r.c.fasthttp
	return &fctx.Request
}
// FormFile returns the first uploaded file stored under key in the multipart
// form, delegating to fasthttp's RequestCtx.FormFile.
func (r *DefaultReq) FormFile(key string) (*multipart.FileHeader, error) {
	fctx := r.c.fasthttp
	return fctx.FormFile(key)
}
// FormValue returns the first form value stored under key. The lookup order is
// QueryArgs, PostArgs, MultipartForm, then FormFile. When no value is found,
// it returns the first defaultValue if one is given, and "" otherwise.
// The result is only valid within the handler: do not keep references to it;
// copy it or enable the Immutable setting instead.
func (r *DefaultReq) FormValue(key string, defaultValue ...string) string {
	raw := r.c.fasthttp.FormValue(key)
	return defaultString(r.c.app.toString(raw), defaultValue)
}
// Fresh returns true when the response is still "fresh" in the client's cache.
// It returns false when the client cache is stale and the full response should
// be sent.
//
// The response is considered stale (false) when:
//   - neither If-None-Match nor If-Modified-Since is present;
//   - the request carries Cache-Control: no-cache (end-to-end reload);
//   - If-None-Match (other than "*") is present, but the response has no ETag
//     or the ETag does not match;
//   - only If-Modified-Since is present and the response has no Last-Modified
//     date to validate it against;
//   - a Last-Modified / If-Modified-Since comparison is made and either date
//     cannot be parsed, or Last-Modified is later than If-Modified-Since.
//
// https://github.com/jshttp/fresh/blob/master/index.js#L33
func (r *DefaultReq) Fresh() bool {
	header := &r.c.fasthttp.Request.Header

	// fields
	modifiedSince := header.Peek(HeaderIfModifiedSince)
	noneMatch := header.Peek(HeaderIfNoneMatch)

	// unconditional request
	if len(modifiedSince) == 0 && len(noneMatch) == 0 {
		return false
	}

	// Always return stale when Cache-Control: no-cache
	// to support end-to-end reload requests
	// https://www.rfc-editor.org/rfc/rfc9111#section-5.2.1.4
	cacheControl := header.Peek(HeaderCacheControl)
	if len(cacheControl) > 0 && isNoCache(utils.UnsafeString(cacheControl)) {
		return false
	}

	response := &r.c.fasthttp.Response

	// notModifiedSince reports whether the (non-empty) Last-Modified date is not
	// later than If-Modified-Since. Unparsable dates are treated as stale.
	notModifiedSince := func(lastModified []byte) bool {
		lastModifiedTime, err := fasthttp.ParseHTTPDate(lastModified)
		if err != nil {
			return false
		}
		modifiedSinceTime, err := fasthttp.ParseHTTPDate(modifiedSince)
		if err != nil {
			return false
		}
		return lastModifiedTime.Compare(modifiedSinceTime) != 1
	}

	// if-none-match
	if len(noneMatch) > 0 {
		// "*" matches any current representation.
		if len(noneMatch) == 1 && noneMatch[0] == '*' {
			return true
		}
		app := r.c.app
		etag := app.toString(response.Header.Peek(HeaderETag))
		if etag == "" || app.isEtagStale(etag, noneMatch) {
			return false
		}
		// The ETag matched. When both validators are sent and the response has
		// a Last-Modified date, it must also satisfy If-Modified-Since.
		if len(modifiedSince) > 0 {
			if lastModified := response.Header.Peek(HeaderLastModified); len(lastModified) > 0 {
				return notModifiedSince(lastModified)
			}
		}
		return true
	}

	// Only If-Modified-Since is present. Without a Last-Modified date the cached
	// copy cannot be validated, so treat it as stale instead of assuming it is
	// fresh.
	lastModified := response.Header.Peek(HeaderLastModified)
	if len(lastModified) == 0 {
		return false
	}
	return notModifiedSince(lastModified)
}
// Get returns the value of the request header named key, looked up
// case-insensitively. If a defaultValue is given, it is returned when
// GetReqHeader reports a parse error for the header.
// The result is only valid within the handler: do not keep references to it;
// copy it or enable the Immutable setting instead.
func (r *DefaultReq) Get(key string, defaultValue ...string) string {
	return GetReqHeader[string](r.c, key, defaultValue...)
}
// GetReqHeader returns the request header named key, converted to type V.
// The raw header value is parsed with genericParseType. When parsing fails and
// a defaultValue is given, the first default is returned. Otherwise, the parsed
// value is returned, which on failure is whatever genericParseType produced.
func GetReqHeader[V GenericType](c Ctx, key string, defaultValue ...V) V {
	raw := c.App().toString(c.Request().Header.Peek(key))
	parsed, err := genericParseType[V](raw)
	if err == nil || len(defaultValue) == 0 {
		return parsed
	}
	return defaultValue[0]
}
// GetHeaders (a.k.a GetReqHeaders) returns all request headers as a map from
// header name to the list of values, in the order fasthttp yields them.
// The strings are only valid within the handler: do not keep references to
// them; copy them or enable the Immutable setting instead.
func (r *DefaultReq) GetHeaders() map[string][]string {
	app := r.c.app
	hdr := &r.c.fasthttp.Request.Header

	// Sized by the header count up front so the map does not have to grow.
	result := make(map[string][]string, hdr.Len())
	for name, value := range hdr.All() {
		k := app.toString(name)
		result[k] = append(result[k], app.toString(value))
	}
	return result
}
// Host returns the request host (hostname plus optional port). When the proxy
// is trusted and X-Forwarded-Host is set, the first comma-separated entry of
// that header is used, with whitespace trimmed. Otherwise, the host from the
// request URI is used.
//
// Example: URL https://example.com:8080 -> Host "example.com:8080".
// Use Hostname for the name without the port.
//
// The result is only valid within the handler; copy it or enable the
// Immutable setting to keep it. Configure Config.TrustProxy to prevent header
// spoofing when running behind a proxy.
func (r *DefaultReq) Host() string {
	if r.IsProxyTrusted() {
		if forwarded := r.Get(HeaderXForwardedHost); forwarded != "" {
			first, _, _ := strings.Cut(forwarded, ",")
			return utils.TrimSpace(first)
		}
	}
	return r.c.app.toString(r.c.fasthttp.Request.URI().Host())
}
// Hostname returns the host reported by Host with any port removed (via
// parseAddr).
//
// Example: URL https://example.com:8080 -> Hostname "example.com".
//
// The result is only valid within the handler; copy it or enable the
// Immutable setting to keep it. Configure Config.TrustProxy to prevent header
// spoofing when running behind a proxy.
func (r *DefaultReq) Hostname() string {
	host, _ := parseAddr(r.Host())
	return host
}
// Port returns the remote port of the request.
// It returns "0" when the connection has no remote address, "" for Unix-socket
// peers, and "" when the address string cannot be split into host and port.
func (r *DefaultReq) Port() string {
	switch remote := r.c.fasthttp.RemoteAddr().(type) {
	case nil:
		return "0"
	case *net.TCPAddr:
		return strconv.Itoa(remote.Port)
	case *net.UnixAddr:
		return ""
	default:
		// Unknown address kinds: fall back to parsing the textual form.
		if _, port, err := net.SplitHostPort(remote.String()); err == nil {
			return port
		}
		return ""
	}
}
// IP returns the remote IP address of the request.
// When the remote peer is a trusted proxy and Config.ProxyHeader is set, the address
// is taken from that header via extractIPFromHeader; otherwise the connection's
// remote IP is returned, or "" when it is unknown.
// If ProxyHeader and IP Validation is configured, it will parse that header and return the first valid IP address.
// Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy.
func (r *DefaultReq) IP() string {
	trusted := r.IsProxyTrusted()
	if proxyHeader := r.c.app.config.ProxyHeader; trusted && proxyHeader != "" {
		return r.extractIPFromHeader(proxyHeader)
	}

	remote := r.c.fasthttp.RemoteIP()
	if remote == nil {
		return ""
	}
	return remote.String()
}
// extractIPsFromHeader returns the comma-separated entries of the given request
// header in the order they appear, with leading spaces/commas and trailing spaces trimmed.
// When IP validation is enabled, any invalid IPs will be omitted: an entry that
// contains ':' is validated as IPv6, otherwise an entry that contains '.' is
// validated as IPv4, and an entry with neither marker is skipped without parsing.
// Checking ':' first keeps IPv4-embedded IPv6 addresses such as
// "::ffff:192.0.2.1" from being rejected by the IPv4 parser.
// NOTE(review): this assumes utils.IsIPv6 accepts the embedded dotted-quad
// form the way net.ParseIP does. If it does not, such entries are still rejected,
// exactly as before.
func (r *DefaultReq) extractIPsFromHeader(header string) []string {
	// TODO: Reuse the c.extractIPFromHeader func somehow in here
	headerValue := r.Get(header)
	validate := r.c.app.config.EnableIPValidation

	// We can't know how many IPs we will return, but we will try to guess with this constant division.
	// Counting ',' makes function slower for about 50ns in general case.
	const maxEstimatedCount = 8
	estimatedCount := min(len(headerValue)/maxEstimatedCount,
		// Avoid big allocation on big header
		maxEstimatedCount)
	ipsFound := make([]string, 0, estimatedCount)

	i := 0
	j := -1
	for {
		var v4, v6 bool

		// Manually splitting string without allocating slice, working with parts directly.
		// j starts one byte past i, so the first byte of each part is not inspected for
		// ':' or '.'. A valid IPv4 never starts with '.', and a valid IPv6 always
		// contains a second ':', so detection is unaffected.
		i, j = j+1, j+2
		if j > len(headerValue) {
			break
		}

		for j < len(headerValue) && headerValue[j] != ',' {
			switch headerValue[j] {
			case ':':
				v6 = true
			case '.':
				v4 = true
			default:
				// do nothing
			}
			j++
		}

		// The uninspected first byte may itself be a separator, so skip spaces and commas.
		for i < j && (headerValue[i] == ' ' || headerValue[i] == ',') {
			i++
		}

		s := utils.TrimRight(headerValue[i:j], ' ')

		if validate {
			// Validate without allocations, parsing only what the markers allow.
			var valid bool
			switch {
			case v6:
				valid = utils.IsIPv6(s)
			case v4:
				valid = utils.IsIPv4(s)
			default:
				valid = false
			}
			if !valid {
				continue
			}
		}

		ipsFound = append(ipsFound, s)
	}

	return ipsFound
}
// extractIPFromHeader will attempt to pull the real client IP from the given header.
//
// When IP validation is disabled, it simply returns the value of that header
// without any inspection, even if it is empty or invalid.
//
// When IP validation is enabled, it returns the first comma-separated entry that
// is a valid IP address, handling entries as follows:
//   - Leading spaces and commas and trailing spaces are trimmed.
//   - An entry containing ':' is validated as IPv6.
//   - Otherwise an entry containing '.' is validated as IPv4.
//   - An entry with neither marker is skipped.
//
// Checking ':' first keeps IPv4-embedded IPv6 such as "::ffff:192.0.2.1" from
// being rejected by the IPv4 parser (assuming utils.IsIPv6 accepts that form the
// way net.ParseIP does — confirm). If no entry is valid, the connection's remote
// IP is returned, or "" when it is unknown.
//
// The work is done in place, without allocating a []string.
func (r *DefaultReq) extractIPFromHeader(header string) string {
	headerValue := r.Get(header)

	// default behavior if IP validation is not enabled is just to return whatever value is
	// in the header. Even if it is empty or invalid
	if !r.c.app.config.EnableIPValidation {
		return headerValue
	}

	i := 0
	j := -1
	for {
		var v4, v6 bool

		// Manually splitting string without allocating slice, working with parts directly.
		// j starts one byte past i, so the first byte of each part is not inspected for
		// ':' or '.'. A valid IPv4 never starts with '.', and a valid IPv6 always
		// contains a second ':', so detection is unaffected.
		i, j = j+1, j+2
		if j > len(headerValue) {
			break
		}

		for j < len(headerValue) && headerValue[j] != ',' {
			switch headerValue[j] {
			case ':':
				v6 = true
			case '.':
				v4 = true
			default:
				// do nothing
			}
			j++
		}

		// The uninspected first byte may itself be a separator (e.g. a leading ','),
		// so skip commas as well as spaces.
		for i < j && (headerValue[i] == ' ' || headerValue[i] == ',') {
			i++
		}

		s := utils.TrimRight(headerValue[i:j], ' ')

		var valid bool
		switch {
		case v6:
			valid = utils.IsIPv6(s)
		case v4:
			valid = utils.IsIPv4(s)
		default:
			valid = false
		}
		if valid {
			return s
		}
	}

	// No valid entry in the header: fall back to the peer address of the connection.
	if ip := r.c.fasthttp.RemoteIP(); ip != nil {
		return ip.String()
	}
	return ""
}
// IPs lists, in order, the addresses carried by the X-Forwarded-For request header.
// With IP validation enabled, entries that are not valid IP addresses are dropped.
func (r *DefaultReq) IPs() []string {
	forwarded := r.extractIPsFromHeader(HeaderXForwardedFor)
	return forwarded
}
// Is reports whether the request's Content-Type (ignoring any ";"-separated
// parameters and surrounding whitespace) matches, case-insensitively, the MIME
// type registered for the given extension. Unknown extensions never match.
func (r *DefaultReq) Is(extension string) bool {
	wanted := utils.GetMIME(extension)
	if wanted == "" {
		return false
	}

	contentType := r.c.app.toString(r.c.fasthttp.Request.Header.ContentType())
	mediaType, _, _ := strings.Cut(contentType, ";")
	return utils.EqualFold(utils.TrimSpace(mediaType), wanted)
}
// Locals stores or retrieves a value scoped to the current request, making it
// available to every subsequent handler that matches the request.
// Called with only a key it returns the stored value (nil if none). Called with a
// value it stores the first value under the key and returns it.
//
// All the values are removed from ctx after returning from the top
// RequestHandler. Additionally, Close method is called on each value
// implementing io.Closer before removing the value from ctx.
func (r *DefaultReq) Locals(key any, value ...any) any {
	if len(value) > 0 {
		stored := value[0]
		r.c.fasthttp.SetUserValue(key, stored)
		return stored
	}
	return r.c.fasthttp.UserValue(key)
}
// Locals is the type-safe counterpart of Ctx.Locals. It stores value[0] under key
// when a value is given, and in every case returns the value held under key as a V.
// If nothing is stored or the stored value is not a V, the zero value of V is returned.
//
// All the values are removed from ctx after returning from the top
// RequestHandler. Additionally, Close method is called on each value
// implementing io.Closer before removing the value from ctx.
func Locals[V any](c Ctx, key any, value ...V) V {
	var raw any
	if len(value) > 0 {
		raw = c.Locals(key, value[0])
	} else {
		raw = c.Locals(key)
	}

	// A failed assertion yields the zero value of V, which is exactly what we want.
	typed, _ := raw.(V)
	return typed
}
// Method returns the HTTP request method of the context.
// When an override is supplied, it is upper-cased. If it names a method the app
// knows, the context's method is switched to it and the override is returned.
// Otherwise the override is ignored and the current method is returned unchanged.
func (r *DefaultReq) Method(override ...string) string {
	app := r.c.app

	if len(override) > 0 {
		candidate := utilsstrings.ToUpper(override[0])
		if idx := app.methodInt(candidate); idx != -1 {
			r.c.methodInt = idx
			return candidate
		}
		// Unknown method: keep the current one.
	}

	return app.method(r.c.methodInt)
}
// MultipartForm parses the request body as multipart form data and returns the
// resulting *multipart.Form, delegating to fasthttp's RequestCtx.MultipartForm.
// In the returned form, Value maps each field name to its slice of string
// values and File maps each field name to its uploaded file headers.
// An error is returned when the body cannot be parsed as a multipart form.
func (r *DefaultReq) MultipartForm() (*multipart.Form, error) {
	return r.c.fasthttp.MultipartForm()
}
// OriginalURL returns the request URI exactly as it appeared in the request line
// (path plus any query string).
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting to use the value outside the Handler.
func (r *DefaultReq) OriginalURL() string {
	requestURI := r.c.fasthttp.Request.Header.RequestURI()
	return r.c.app.toString(requestURI)
}
// Params returns the value captured for the named route parameter.
// "*" and "+" are shorthands for the first wildcard ("*1" / "+1").
// Names are matched exactly or, unless Config.CaseSensitive is set, case-insensitively.
// When the parameter is unknown or captured an empty value, the first default value is
// returned, or "" when none is given.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting to use the value outside the Handler.
func (r *DefaultReq) Params(key string, defaultValue ...string) string {
	switch key {
	case "*", "+":
		key += "1"
	}

	app := r.c.app
	captured := &r.c.values

	for idx, name := range r.c.Route().Params {
		// Cheap length check before any string comparison.
		if len(name) != len(key) {
			continue
		}
		if name != key && (app.config.CaseSensitive || !utils.EqualFold(name, key)) {
			continue
		}
		// Matched the name; an absent or empty capture falls back to the default.
		if idx >= len(captured) || captured[idx] == "" {
			break
		}
		return app.GetString(captured[idx])
	}

	return defaultString("", defaultValue)
}
// Params reads the named route parameter and converts it to type V.
// If the conversion fails and a default value is supplied, that default is returned;
// otherwise the result of the conversion is returned as-is.
//
// Example:
//
//	http://example.com/user/:user -> http://example.com/user/john
//	Params[string](c, "user") -> returns john
//
//	http://example.com/id/:id -> http://example.com/user/114
//	Params[int](c, "id") ->  returns 114 as integer.
//
//	http://example.com/id/:number -> http://example.com/id/john
//	Params[int](c, "number", 0) -> returns 0 because can't parse 'john' as integer.
func Params[V GenericType](c Ctx, key string, defaultValue ...V) V {
	parsed, err := genericParseType[V](c.Params(key))
	if err == nil || len(defaultValue) == 0 {
		return parsed
	}
	return defaultValue[0]
}
// Scheme returns the request protocol scheme.
//
// TLS connections always yield "https". Untrusted peers always yield "http".
// For a trusted proxy, the forwarding headers matched by the xForwardedProto /
// xForwardedProtocol / xForwardedSsl / xURLScheme constants may override the
// default "http":
//   - Proto/Protocol contribute their first comma-separated value.
//   - Ssl contributes "https" when its value equals onBytes.
//   - The URL-scheme header contributes its whole value.
//
// When several apply, the last one seen wins. The result is trimmed and lower-cased.
// Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy.
func (r *DefaultReq) Scheme() string {
	fctx := r.c.fasthttp

	switch {
	case fctx.IsTLS():
		return schemeHTTPS
	case !r.IsProxyTrusted():
		return schemeHTTP
	}

	app := r.c.app
	result := schemeHTTP

	// Both "X-Forwarded-" and "X-Url-Scheme" are 12 bytes long, so shorter names can be skipped.
	const minRelevantLen = 12

	for name, value := range fctx.Request.Header.All() {
		if len(name) < minRelevantLen {
			continue
		}

		switch {
		case utils.EqualFold(name[:len(xForwardedPrefix)], xForwardedPrefix):
			switch {
			case utils.EqualFold(name, xForwardedProtoBytes), utils.EqualFold(name, xForwardedProtocolBytes):
				first, _, _ := strings.Cut(app.toString(value), ",")
				result = utils.TrimSpace(first)
			case utils.EqualFold(name, xForwardedSslBytes) && utils.EqualFold(value, onBytes):
				result = schemeHTTPS
			}
		case utils.EqualFold(name, xURLSchemeBytes):
			result = utils.TrimSpace(app.toString(value))
		}
	}

	return utilsstrings.ToLower(utils.TrimSpace(result))
}
// Protocol returns the protocol string from the request line, e.g. "HTTP/1.1".
func (r *DefaultReq) Protocol() string {
	proto := r.c.fasthttp.Request.Header.Protocol()
	return r.c.app.toString(proto)
}
// Query returns the raw value of the named query-string parameter.
// It is the string instantiation of the generic Query helper. When the helper
// cannot produce a value, the first default value (if any) is returned.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting to use the value outside the Handler.
func (r *DefaultReq) Query(key string, defaultValue ...string) string {
	return Query[string](r.c, key, defaultValue...)
}
// Queries returns every query-string parameter as a map from key to value.
// For repeated keys, the value of the last occurrence is kept.
// Keys are used verbatim, so bracket and dot notations are not expanded.
//
//	GET /?name=alex&wanna_cake=2&id=
//	Queries()["name"] == "alex"
//	Queries()["wanna_cake"] == "2"
//	Queries()["id"] == ""
//
//	GET /?field1=value1&field1=value2&field2=value3
//	Queries()["field1"] == "value2"
//	Queries()["field2"] == "value3"
//
//	GET /?list_a=1&list_a=2&list_a=3&list_b[]=1&list_b[]=2&list_b[]=3&list_c=1,2,3
//	Queries()["list_a"] == "3"
//	Queries()["list_b[]"] == "3"
//	Queries()["list_c"] == "1,2,3"
//
//	GET /api/search?filters.author.name=John&filters.category.name=Technology&filters[customer][name]=Alice&filters[status]=pending
//	Queries()["filters.author.name"] == "John"
//	Queries()["filters.category.name"] == "Technology"
//	Queries()["filters[customer][name]"] == "Alice"
//	Queries()["filters[status]"] == "pending"
func (r *DefaultReq) Queries() map[string]string {
	a := r.c.app
	args := r.c.fasthttp.QueryArgs()

	result := make(map[string]string, args.Len())
	for k, v := range args.All() {
		result[a.toString(k)] = a.toString(v)
	}
	return result
}
// Query retrieves the value of a query parameter from the request's URI and
// converts it to type V.
//
// Parameters:
//   - c: the context of the current request.
//   - key: the name of the query parameter.
//   - defaultValue: (optional) the value to return if the query parameter is not
//     found or cannot be parsed.
//
// The raw value is read from the request's query arguments (c.RequestCtx().QueryArgs())
// and converted with genericParseType[V]. If that conversion reports an error and a
// default value is provided, the default is returned. Otherwise the converted value is
// returned as-is. For a missing parameter, the example below relies on
// genericParseType reporting an error for the empty raw value.
//
// If the generic type cannot be matched to a supported type, the function returns the default value (if provided) or the zero value of type V.
//
// Example usage:
//
//	GET /?search=john&age=8
//	name := Query[string](c, "search") // Returns "john"
//	age := Query[int](c, "age") // Returns 8
//	unknown := Query[string](c, "unknown", "default") // Returns "default" since the query parameter "unknown" is not found
func Query[V GenericType](c Ctx, key string, defaultValue ...V) V {
	q := c.App().toString(c.RequestCtx().QueryArgs().Peek(key))
	v, err := genericParseType[V](q)
	if err != nil && len(defaultValue) > 0 {
		return defaultValue[0]
	}
	return v
}
// Range parses the request's Range header for a representation of `size` bytes
// and returns the range unit together with the satisfiable byte ranges.
//
// As implemented below:
//   - The header must look like "<unit>=<ranges>" with exactly one '=' and the
//     unit "bytes" (case-insensitive); otherwise ErrRangeMalformed is returned.
//     An empty or missing header is therefore also ErrRangeMalformed.
//   - Each comma-separated entry must contain '-'. "-N" selects the last N
//     bytes and "N-" selects from N to the end. A bound that fails to parse is
//     treated like an absent one (so "a-700" behaves like "-700"). A bound that
//     parses but exceeds math.MaxInt64 yields ErrRangeMalformed.
//   - The last-byte position is clamped to size-1. Entries whose start then lies
//     after their end (or is negative) are dropped.
//   - If there are more than Config.MaxRanges entries, or no entry survives, the
//     response status is set to 416 with a "Content-Range: bytes */<size>"
//     header and ErrRangeTooLarge or ErrRequestedRangeNotSatisfiable is returned.
func (r *DefaultReq) Range(size int64) (Range, error) {
	var (
		rangeData Range
		ranges    string
	)
	rangeStr := utils.TrimSpace(r.Get(HeaderRange))

	maxRanges := r.c.app.config.MaxRanges
	// Cap the up-front allocation so a large MaxRanges does not preallocate a big slice.
	const maxRangePrealloc = 8
	prealloc := min(maxRanges, maxRangePrealloc)
	if prealloc > 0 {
		rangeData.Ranges = make([]RangeSet, 0, prealloc)
	}

	// parseBound parses one range bound. Values that do not fit in int64 are
	// reported as ErrRangeMalformed, so the caller can tell them apart from
	// ordinary parse failures such as an empty bound.
	parseBound := func(value string) (int64, error) {
		parsed, err := utils.ParseUint(value)
		if err != nil {
			return 0, fmt.Errorf("parse range bound %q: %w", value, err)
		}
		if parsed > (math.MaxUint64 >> 1) {
			return 0, ErrRangeMalformed
		}
		return int64(parsed), nil
	}

	// Split "<unit>=<ranges>"; a second '=' makes the header malformed.
	before, after, found := strings.Cut(rangeStr, "=")
	if !found || strings.IndexByte(after, '=') >= 0 {
		return rangeData, ErrRangeMalformed
	}
	rangeData.Type = utilsstrings.ToLower(utils.TrimSpace(before))
	if rangeData.Type != "bytes" {
		return rangeData, ErrRangeMalformed
	}
	ranges = utils.TrimSpace(after)

	var (
		singleRange string
		moreRanges  = ranges
		rangeCount  int
	)
	for moreRanges != "" {
		// Count every entry (including ones later dropped) against MaxRanges.
		rangeCount++
		if rangeCount > maxRanges {
			r.c.DefaultRes.Status(StatusRequestedRangeNotSatisfiable)
			r.c.DefaultRes.Set(HeaderContentRange, "bytes */"+utils.FormatInt(size)) //nolint:staticcheck // It is fine to ignore the static check
			return rangeData, ErrRangeTooLarge
		}
		// Pop the next comma-separated entry off moreRanges.
		singleRange = moreRanges
		if i := strings.IndexByte(moreRanges, ','); i >= 0 {
			singleRange = moreRanges[:i]
			moreRanges = utils.TrimSpace(moreRanges[i+1:])
		} else {
			moreRanges = ""
		}
		singleRange = utils.TrimSpace(singleRange)

		var (
			startStr, endStr string
			i                int
		)
		if i = strings.IndexByte(singleRange, '-'); i == -1 {
			return rangeData, ErrRangeMalformed
		}
		startStr = utils.TrimSpace(singleRange[:i])
		endStr = utils.TrimSpace(singleRange[i+1:])

		start, startErr := parseBound(startStr)
		end, endErr := parseBound(endStr)
		// Only out-of-range bounds are fatal; other parse errors mean "bound absent".
		if errors.Is(startErr, ErrRangeMalformed) || errors.Is(endErr, ErrRangeMalformed) {
			return rangeData, ErrRangeMalformed
		}
		if startErr != nil { // -nnn
			// Suffix range: the last `end` bytes (whole representation if end >= size).
			start = max(size-end, 0)
			end = size - 1
		} else if endErr != nil { // nnn-
			end = size - 1
		}
		if end > size-1 { // limit last-byte-pos to current length
			end = size - 1
		}
		// Unsatisfiable entries are skipped rather than failing the whole header.
		if start > end || start < 0 {
			continue
		}
		rangeData.Ranges = append(rangeData.Ranges, RangeSet{
			Start: start,
			End:   end,
		})
	}
	if len(rangeData.Ranges) < 1 {
		r.c.DefaultRes.Status(StatusRequestedRangeNotSatisfiable)
		r.c.DefaultRes.Set(HeaderContentRange, "bytes */"+utils.FormatInt(size)) //nolint:staticcheck // It is fine to ignore the static check
		return rangeData, ErrRequestedRangeNotSatisfiable
	}

	return rangeData, nil
}
// Route returns the Route that matched the current request, as reported by the
// owning context's Route method.
func (r *DefaultReq) Route() *Route {
	matched := r.c.Route()
	return matched
}
// Subdomains splits the request hostname into dot-separated labels and returns
// them without the last `offset` labels. The offset defaults to 2, which drops
// e.g. "example.com".
//
// The hostname is lower-cased and stripped of trailing dots. When it contains
// "xn--" punycode labels, it is decoded to Unicode. IP literals (optionally
// bracketed) yield an empty slice.
//
// A negative offset, or one that removes every label, also yields an empty
// slice. An offset of zero returns every label.
func (r *DefaultReq) Subdomains(offset ...int) []string {
	drop := 2
	if len(offset) != 0 {
		drop = offset[0]
	}
	if drop < 0 {
		return []string{}
	}

	// Normalize host according to RFC 3986.
	host := r.Hostname()
	if strings.HasSuffix(host, ".") {
		// A fully-qualified domain name carries trailing dot(s).
		host = utils.TrimRight(host, '.')
	}
	host = utilsstrings.ToLower(host)

	// Only pay for IDNA decoding when a punycode label is present.
	if strings.Contains(host, "xn--") {
		if decoded, err := idna.Lookup.ToUnicode(host); err == nil {
			host = utilsstrings.ToLower(decoded)
		}
	}

	// IP literals have no subdomains.
	literal := host
	if strings.HasPrefix(literal, "[") && strings.HasSuffix(literal, "]") {
		literal = literal[1 : len(literal)-1]
	}
	if utils.IsIPv4(literal) || utils.IsIPv6(literal) {
		return []string{}
	}

	// A small fixed-size scratch array holds typical hosts without growing. The
	// result is always copied into a freshly allocated slice, so no view of the
	// scratch array escapes.
	var scratch [8]string
	labels := scratch[:0]
	for label := range strings.SplitSeq(host, ".") {
		labels = append(labels, label)
	}

	// SplitSeq always yields at least one label, so drop == 0 keeps everything.
	keep := len(labels) - drop
	if keep <= 0 {
		return []string{}
	}
	result := make([]string, keep)
	copy(result, labels)
	return result
}
// Stale reports whether the client's cached response must be refreshed; it is
// always the logical negation of Fresh.
func (r *DefaultReq) Stale() bool {
	fresh := r.Fresh()
	return !fresh
}
// IsProxyTrusted reports whether the remote peer of the request is a trusted proxy.
//
// It always returns false when Config.TrustProxy is disabled. Otherwise:
//   - Unix-socket peers are trusted exactly when TrustProxyConfig.UnixSocket is set.
//   - TCP/UDP peers are trusted when their IP is loopback, private or link-local
//     and the matching TrustProxyConfig flag is enabled.
//   - TCP/UDP peers are also trusted when their IP is one of the configured
//     trusted IPs or falls inside a configured trusted range.
//   - Any other (or missing) address type is not trusted.
func (r *DefaultReq) IsProxyTrusted() bool {
	// Borrow the config by pointer instead of copying the whole Config value on
	// every call; IP, Host and Scheme consult this method on each request.
	// NOTE(review): relies on App.config being a Config value rather than a pointer.
	cfg := &r.c.app.config
	if !cfg.TrustProxy {
		return false
	}

	switch r.c.fasthttp.RemoteAddr().(type) {
	case *net.UnixAddr:
		return cfg.TrustProxyConfig.UnixSocket
	case *net.TCPAddr, *net.UDPAddr:
		// Continue with the RemoteIP/IP-map/CIDR checks below.
	default:
		// Unknown address type: do not trust by default.
		return false
	}

	ip := r.c.fasthttp.RemoteIP()
	if ip == nil {
		return false
	}

	if (cfg.TrustProxyConfig.Loopback && ip.IsLoopback()) ||
		(cfg.TrustProxyConfig.Private && ip.IsPrivate()) ||
		(cfg.TrustProxyConfig.LinkLocal && ip.IsLinkLocalUnicast()) {
		return true
	}

	if _, trusted := cfg.TrustProxyConfig.ips[ip.String()]; trusted {
		return true
	}

	for _, ipNet := range cfg.TrustProxyConfig.ranges {
		if ipNet.Contains(ip) {
			return true
		}
	}

	return false
}
// IsFromLocal reports whether the request originated on the local machine.
// Unix-socket peers always count as local, since only processes on the same host
// can connect. Otherwise the peer counts as local when its remote IP is a
// loopback address; an unknown IP counts as not local.
func (r *DefaultReq) IsFromLocal() bool {
	if _, viaUnixSocket := r.c.fasthttp.RemoteAddr().(*net.UnixAddr); viaUnixSocket {
		return true
	}
	peer := r.c.fasthttp.RemoteIP()
	return peer != nil && peer.IsLoopback()
}
// release clears the back-reference to the owning context so that this Req no
// longer points at it; it is intended for use when the context is released via
// ReleaseCtx().
func (r *DefaultReq) release() {
	r.c = nil // drop the context reference
}
// getBody returns the raw request body as passed through App.GetBytes
// (presumably copying it when the Immutable setting is enabled — confirm in app.go).
func (r *DefaultReq) getBody() []byte {
	raw := r.c.fasthttp.Request.Body()
	return r.c.app.GetBytes(raw)
}
================================================
FILE: req_interface_gen.go
================================================
// Code generated by ifacemaker; DO NOT EDIT.
package fiber
import (
"mime/multipart"
"github.com/valyala/fasthttp"
)
// Req is an interface for request-related Ctx methods.
type Req interface {
	// Accepts checks if the specified extensions or content types are acceptable.
	Accepts(offers ...string) string
	// AcceptsCharsets checks if the specified charset is acceptable.
	AcceptsCharsets(offers ...string) string
	// AcceptsEncodings checks if the specified encoding is acceptable.
	AcceptsEncodings(offers ...string) string
	// AcceptsLanguages checks if the specified language is acceptable using
	// RFC 4647 Basic Filtering.
	AcceptsLanguages(offers ...string) string
	// AcceptsLanguagesExtended checks if the specified language is acceptable using
	// RFC 4647 Extended Filtering.
	AcceptsLanguagesExtended(offers ...string) string
	// App returns the *App reference to the instance of the Fiber application
	App() *App
	// BaseURL returns (protocol + host + base path).
	BaseURL() string
	// BodyRaw contains the raw body submitted in a POST request.
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting instead.
	BodyRaw() []byte
	//nolint:nonamedreturns // gocritic unnamedResult prefers naming decoded body, decode count, and error
	tryDecodeBodyInOrder(originalBody *[]byte, encodings []string) (body []byte, decodesRealized uint8, err error)
	// Body contains the raw body submitted in a POST request.
	// This method will decompress the body if the 'Content-Encoding' header is provided.
	// It returns the original (or decompressed) body data which is valid only within the handler.
	// Don't store direct references to the returned data.
	// If you need to keep the body's data later, make a copy or use the Immutable option.
	Body() []byte
	// RequestCtx returns *fasthttp.RequestCtx that carries a deadline
	// a cancellation signal, and other values across API boundaries.
	RequestCtx() *fasthttp.RequestCtx
	// Cookies are used for getting a cookie value by key.
	// Defaults to the empty string "" if the cookie doesn't exist.
	// If a default value is given, it will return that value if the cookie doesn't exist.
	// The returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting to use the value outside the Handler.
	Cookies(key string, defaultValue ...string) string
	// Request return the *fasthttp.Request object
	// This allows you to use all fasthttp request methods
	// https://godoc.org/github.com/valyala/fasthttp#Request
	Request() *fasthttp.Request
	// FormFile returns the first file by key from a MultipartForm.
	FormFile(key string) (*multipart.FileHeader, error)
	// FormValue returns the first value by key from a MultipartForm.
	// Search is performed in QueryArgs, PostArgs, MultipartForm and FormFile in this particular order.
	// Defaults to the empty string "" if the form value doesn't exist.
	// If a default value is given, it will return that value if the form value does not exist.
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting instead.
	FormValue(key string, defaultValue ...string) string
	// Fresh returns true when the response is still “fresh” in the client's cache,
	// otherwise false is returned to indicate that the client cache is now stale
	// and the full response should be sent.
	// When a client sends the Cache-Control: no-cache request header to indicate an end-to-end
	// reload request, this module will return false to make handling these requests transparent.
	// https://github.com/jshttp/fresh/blob/master/index.js#L33
	Fresh() bool
	// Get returns the HTTP request header specified by field.
	// Field names are case-insensitive
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting instead.
	Get(key string, defaultValue ...string) string
	// GetHeaders (a.k.a GetReqHeaders) returns the HTTP request headers.
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting instead.
	GetHeaders() map[string][]string
	// Host contains the host derived from the X-Forwarded-Host or Host HTTP header.
	// Returned value is only valid within the handler. Do not store any references.
	// In a network context, `Host` refers to the combination of a hostname and potentially a port number used for connecting,
	// while `Hostname` refers specifically to the name assigned to a device on a network, excluding any port information.
	// Example: URL: https://example.com:8080 -> Host: example.com:8080
	// Make copies or use the Immutable setting instead.
	// Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy.
	Host() string
	// Hostname contains the hostname derived from the X-Forwarded-Host or Host HTTP header using the c.Host() method.
	// Returned value is only valid within the handler. Do not store any references.
	// Example: URL: https://example.com:8080 -> Hostname: example.com
	// Make copies or use the Immutable setting instead.
	// Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy.
	Hostname() string
	// Port returns the remote port of the request.
	Port() string
	// IP returns the remote IP address of the request.
	// If ProxyHeader and IP Validation is configured, it will parse that header and return the first valid IP address.
	// Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy.
	IP() string
	// extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear.
	// When IP validation is enabled, any invalid IPs will be omitted.
	extractIPsFromHeader(header string) []string
	// extractIPFromHeader will attempt to pull the real client IP from the given header when IP validation is enabled.
	// currently, it will return the first valid IP address in header.
	// when IP validation is disabled, it will simply return the value of the header without any inspection.
	// Implementation is almost the same as in extractIPsFromHeader, but without allocation of []string.
	extractIPFromHeader(header string) string
	// IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header.
	// When IP validation is enabled, only valid IPs are returned.
	IPs() []string
	// Is returns the matching content type,
	// if the incoming request's Content-Type HTTP header field matches the MIME type specified by the type parameter
	Is(extension string) bool
	// Locals makes it possible to pass any values under keys scoped to the request
	// and therefore available to all following routes that match the request.
	//
	// All the values are removed from ctx after returning from the top
	// RequestHandler. Additionally, Close method is called on each value
	// implementing io.Closer before removing the value from ctx.
	Locals(key any, value ...any) any
	// Method returns the HTTP request method for the context, optionally overridden by the provided argument.
	// If no override is given or if the provided override is not a valid HTTP method, it returns the current method from the context.
	// Otherwise, it updates the context's method and returns the overridden method as a string.
	Method(override ...string) string
	// MultipartForm parse form entries from binary.
	// This returns a map[string][]string, so given a key, the value will be a string slice.
	MultipartForm() (*multipart.Form, error)
	// OriginalURL contains the original request URL.
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting to use the value outside the Handler.
	OriginalURL() string
	// Params is used to get the route parameters.
	// Defaults to empty string "" if the param doesn't exist.
	// If a default value is given, it will return that value if the param doesn't exist.
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting to use the value outside the Handler.
	Params(key string, defaultValue ...string) string
	// Scheme contains the request protocol string: http or https for TLS requests.
	// Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy.
	Scheme() string
	// Protocol returns the HTTP protocol of request: HTTP/1.1 and HTTP/2.
	Protocol() string
	// Query returns the query string parameter in the url.
	// Defaults to empty string "" if the query doesn't exist.
	// If a default value is given, it will return that value if the query doesn't exist.
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting to use the value outside the Handler.
	Query(key string, defaultValue ...string) string
	// Queries returns a map of query parameters and their values.
	//
	// GET /?name=alex&wanna_cake=2&id=
	// Queries()["name"] == "alex"
	// Queries()["wanna_cake"] == "2"
	// Queries()["id"] == ""
	//
	// GET /?field1=value1&field1=value2&field2=value3
	// Queries()["field1"] == "value2"
	// Queries()["field2"] == "value3"
	//
	// GET /?list_a=1&list_a=2&list_a=3&list_b[]=1&list_b[]=2&list_b[]=3&list_c=1,2,3
	// Queries()["list_a"] == "3"
	// Queries()["list_b[]"] == "3"
	// Queries()["list_c"] == "1,2,3"
	//
	// GET /api/search?filters.author.name=John&filters.category.name=Technology&filters[customer][name]=Alice&filters[status]=pending
	// Queries()["filters.author.name"] == "John"
	// Queries()["filters.category.name"] == "Technology"
	// Queries()["filters[customer][name]"] == "Alice"
	// Queries()["filters[status]"] == "pending"
	Queries() map[string]string
	// Range returns a struct containing the type and a slice of ranges.
	Range(size int64) (Range, error)
	// Route returns the matched Route struct.
	Route() *Route
	// Subdomains returns a slice of subdomains from the host, excluding the last `offset` components.
	// If the offset is negative or exceeds the number of subdomains, an empty slice is returned.
	// If the offset is zero every label (no trimming) is returned.
	Subdomains(offset ...int) []string
	// Stale returns the inverse of Fresh, indicating if the client's cached response is considered stale.
	Stale() bool
	// IsProxyTrusted checks trustworthiness of remote ip.
	// If Config.TrustProxy false, it returns false.
	// IsProxyTrusted can check remote ip by proxy ranges and ip map.
	IsProxyTrusted() bool
	// IsFromLocal will return true if request came from local.
	IsFromLocal() bool
	// Release is a method to reset Req fields when to use ReleaseCtx()
	release()
	getBody() []byte
}
================================================
FILE: res.go
================================================
package fiber
import (
	"bufio"
	"bytes"
	"fmt"
	"html/template"
	"io"
	"io/fs"
	"net/http"
	"net/url"
	"os"
	pathpkg "path"
	"path/filepath"
	"reflect"
	"strconv"
	"strings"
	"time"
	"unicode"
	"unicode/utf8"

	"github.com/gofiber/utils/v2"
	"github.com/valyala/bytebufferpool"
	"github.com/valyala/fasthttp"
)
// SendFile defines the configuration options used when transferring a file with SendFile.
// Every field takes part in sendFileStore.configEqual, so two SendFile calls share a
// cached file handler only when all of these options are equal.
type SendFile struct {
	// FS is the file system to serve the static files from.
	// You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc.
	//
	// Optional. Default: nil
	FS fs.FS

	// When set to true, the server tries minimizing CPU usage by caching compressed files.
	// This works differently than the github.com/gofiber/compression middleware.
	// You have to set Content-Encoding header to compress the file.
	// Available compression methods are gzip, br, and zstd.
	//
	// Optional. Default: false
	Compress bool `json:"compress"`

	// When set to true, enables byte range requests.
	//
	// Optional. Default: false
	ByteRange bool `json:"byte_range"`

	// When set to true, enables direct download.
	//
	// Optional. Default: false
	Download bool `json:"download"`

	// Expiration duration for inactive file handlers.
	// Use a negative time.Duration to disable it.
	//
	// Optional. Default: 10 * time.Second
	CacheDuration time.Duration `json:"cache_duration"`

	// The value for the Cache-Control HTTP-header
	// that is set on the file response. MaxAge is defined in seconds.
	//
	// Optional. Default: 0
	MaxAge int `json:"max_age"`
}
// sendFileStore keeps a fasthttp file-serving handler together with the SendFile
// configuration it was built for. configEqual compares that configuration,
// presumably so that later SendFile calls with an equal configuration can reuse
// the handler.
type sendFileStore struct {
	// handler serves files according to config.
	handler fasthttp.RequestHandler
	// cacheControlValue is a precomputed Cache-Control header value
	// (presumably derived from config.MaxAge — confirm at the construction site).
	cacheControlValue string
	// config is the SendFile configuration this store was created with.
	config SendFile
}
// configEqual reports whether cfg is identical to the configuration this store was
// built for, comparing every SendFile field.
//
// Here we don't use reflect.DeepEqual because it is quite slow compared to manual
// comparison. The cheap scalar fields are compared first. The FS field is compared
// with sendFileFSEqual, because a plain != on the interface values panics at run time
// when both hold the same non-comparable dynamic type (e.g. fstest.MapFS, which is a map).
func (sf *sendFileStore) configEqual(cfg SendFile) bool {
	return sf.config.Compress == cfg.Compress &&
		sf.config.ByteRange == cfg.ByteRange &&
		sf.config.Download == cfg.Download &&
		sf.config.CacheDuration == cfg.CacheDuration &&
		sf.config.MaxAge == cfg.MaxAge &&
		sendFileFSEqual(sf.config.FS, cfg.FS)
}

// sendFileFSEqual reports whether a and b denote the same file system without ever panicking.
//   - Comparable dynamic values (structs, strings, pointers, ...) are compared with ==.
//   - Map and slice values, which == cannot compare, are compared by identity.
//   - Values of any other non-comparable kind are treated as different.
func sendFileFSEqual(a, b fs.FS) bool {
	if a == nil || b == nil {
		return a == nil && b == nil
	}

	va, vb := reflect.ValueOf(a), reflect.ValueOf(b)
	if va.Type() != vb.Type() {
		return false
	}
	// Value.Comparable inspects the dynamic contents, so a true result guarantees
	// that the interface comparison below cannot panic.
	if va.Comparable() {
		return a == b
	}

	switch va.Kind() {
	case reflect.Map:
		return va.UnsafePointer() == vb.UnsafePointer()
	case reflect.Slice:
		return va.UnsafePointer() == vb.UnsafePointer() && va.Len() == vb.Len()
	default:
		return false
	}
}
// Cookie defines the values used when configuring cookies emitted by
// DefaultRes.Cookie.
type Cookie struct {
	// Expires is the absolute expiration time; ignored when SessionOnly is true.
	Expires time.Time `json:"expires"`
	// Name is the cookie name.
	Name string `json:"name"`
	// Value is the cookie value.
	Value string `json:"value"`
	// Path is the URL path allowed to receive the cookie; DefaultRes.Cookie uses "/" when empty.
	Path string `json:"path"`
	// Domain is the domain allowed to receive the cookie.
	Domain string `json:"domain"`
	// SameSite is one of the CookieSameSite* values (matched case-insensitively);
	// unrecognized values fall back to Lax. SameSite=None forces Secure.
	SameSite string `json:"same_site"`
	// MaxAge is the maximum age of the cookie in seconds; ignored when SessionOnly is true.
	MaxAge int `json:"max_age"`
	// Secure restricts the cookie to HTTPS connections.
	Secure bool `json:"secure"`
	// HTTPOnly makes the cookie inaccessible to client-side scripts.
	HTTPOnly bool `json:"http_only"`
	// Partitioned stores the cookie in a partitioned cookie jar (CHIPS); forces Secure.
	Partitioned bool `json:"partitioned"`
	// SessionOnly drops MaxAge and Expires so the cookie lives for the browser session.
	SessionOnly bool `json:"session_only"`
}
// ResFmt associates a Content Type to a fiber.Handler for c.Format.
// Format calls Handler when MediaType is the negotiated type; the special
// MediaType "default" marks the fallback handler used when nothing matches.
type ResFmt struct {
	// Handler is invoked when this entry is selected.
	Handler func(Ctx) error
	// MediaType is the media type offered to content negotiation, or "default".
	MediaType string
}
// DefaultRes is the default implementation of Res used by DefaultCtx.
// All of its methods operate on the fasthttp response of the owning context.
//
//go:generate ifacemaker --file res.go --struct DefaultRes --iface Res --pkg fiber --output res_interface_gen.go --not-exported true --iface-comment "Res is an interface for response-related Ctx methods."
type DefaultRes struct {
	// c is the owning context; it is set to nil by release.
	c *DefaultCtx
}
// App returns the *App reference to the instance of the Fiber application
// that owns this response.
func (r *DefaultRes) App() *App {
	ctx := r.c
	return ctx.app
}
// Append the specified value to the HTTP response header field.
// If the header is not already set, it creates the header with the specified value.
// Values already present as comma-separated elements are not added again, and
// empty values are ignored (appending one would leave a dangling ", ").
func (r *DefaultRes) Append(field string, values ...string) {
	if len(values) == 0 {
		return
	}

	h := r.c.app.toString(r.c.fasthttp.Response.Header.Peek(field))
	originalH := h
	for _, value := range values {
		if value == "" {
			continue
		}
		if h == "" {
			h = value
		} else if !headerContainsValue(h, value) {
			h += ", " + value
		}
	}

	// Only touch the header when something was actually appended.
	if originalH != h {
		r.Set(field, h)
	}
}
// headerContainsValue checks if a header value already contains the given value
// as a comma-separated element. Per RFC 9110, list elements are separated by commas
// with optional whitespace (OWS) around them.
func headerContainsValue(header, value string) bool {
// Empty value should never match
if value == "" {
return false
}
// Exact match (single value header)
if header == value {
return true
}
// Check each comma-separated element, handling optional whitespace (OWS)
for part := range strings.SplitSeq(header, ",") {
if utils.TrimSpace(part) == value {
return true
}
}
return false
}
// sanitizeFilename removes every Unicode control character from filename and
// trims surrounding whitespace, so the name is safe to embed in a header value.
func sanitizeFilename(filename string) string {
	// Fast path: nothing to strip.
	if strings.IndexFunc(filename, unicode.IsControl) < 0 {
		return utils.TrimSpace(filename)
	}

	cleaned := make([]byte, 0, len(filename))
	for _, ch := range filename {
		if unicode.IsControl(ch) {
			continue
		}
		cleaned = utf8.AppendRune(cleaned, ch)
	}
	return utils.TrimSpace(string(cleaned))
}
// fallbackFilenameIfInvalid substitutes "download" for names that cannot be used
// as a download filename (empty, or "." as produced by filepath.Base("")).
func fallbackFilenameIfInvalid(filename string) string {
	switch filename {
	case "", ".":
		return "download"
	default:
		return filename
	}
}
// Attachment sets the HTTP response Content-Disposition header field to attachment.
// When a filename is given, its base name is sanitized and added as the filename
// parameter, and the Content-Type is set from its extension. Non-ASCII names are
// additionally emitted as an RFC 5987 filename* parameter.
func (r *DefaultRes) Attachment(filename ...string) {
	if len(filename) == 0 {
		r.setCanonical(HeaderContentDisposition, "attachment")
		return
	}

	name := fallbackFilenameIfInvalid(sanitizeFilename(filepath.Base(filename[0])))
	r.Type(filepath.Ext(name))

	app := r.c.app
	ascii := app.isASCII(name)

	var quoted string
	if ascii {
		quoted = app.quoteString(name)
	} else {
		quoted = app.quoteRawString(name)
	}

	disposition := `attachment; filename="` + quoted + `"`
	if !ascii {
		disposition += `; filename*=UTF-8''` + url.PathEscape(name)
	}
	r.setCanonical(HeaderContentDisposition, disposition)
}
// ClearCookie expires a specific cookie by key on the client side.
// If no key is provided it expires all cookies that came with the request.
func (r *DefaultRes) ClearCookie(key ...string) {
	respHeader := &r.c.fasthttp.Response.Header

	if len(key) == 0 {
		// No explicit keys: clear every cookie the client sent.
		for name := range r.c.fasthttp.Request.Header.Cookies() {
			respHeader.DelClientCookieBytes(name)
		}
		return
	}

	for _, name := range key {
		respHeader.DelClientCookie(name)
	}
}
// RequestCtx returns *fasthttp.RequestCtx that carries a deadline
// a cancellation signal, and other values across API boundaries.
func (r *DefaultRes) RequestCtx() *fasthttp.RequestCtx {
	ctx := r.c
	return ctx.fasthttp
}
// Cookie sets a cookie by passing a cookie struct.
//
// The cookie is validated with net/http's (*http.Cookie).Valid; invalid cookies
// are silently ignored, as net/http does. Before validation the following
// normalizations are written back into *cookie:
//   - an empty Path becomes "/";
//   - SessionOnly clears MaxAge and Expires;
//   - SameSite=None and Partitioned force Secure=true.
//
// SameSite is matched case-insensitively; unknown values default to Lax.
func (r *DefaultRes) Cookie(cookie *Cookie) {
	if cookie.Path == "" {
		cookie.Path = "/"
	}
	if cookie.SessionOnly {
		cookie.MaxAge = 0
		cookie.Expires = time.Time{}
	}

	// Resolve SameSite once into both the net/http value (for validation) and
	// the fasthttp value (for serialization).
	sameSite := http.SameSiteLaxMode
	fastSameSite := fasthttp.CookieSameSiteLaxMode
	switch {
	case utils.EqualFold(cookie.SameSite, CookieSameSiteStrictMode):
		sameSite, fastSameSite = http.SameSiteStrictMode, fasthttp.CookieSameSiteStrictMode
	case utils.EqualFold(cookie.SameSite, CookieSameSiteNoneMode):
		sameSite, fastSameSite = http.SameSiteNoneMode, fasthttp.CookieSameSiteNoneMode
		// SameSite=None requires Secure=true per RFC and browser requirements
		cookie.Secure = true
	case utils.EqualFold(cookie.SameSite, CookieSameSiteDisabled):
		sameSite, fastSameSite = 0, fasthttp.CookieSameSiteDisabled
	}

	// Partitioned requires Secure=true per CHIPS spec
	if cookie.Partitioned {
		cookie.Secure = true
	}

	// Let net/http decide whether the cookie is well-formed.
	std := &http.Cookie{
		Name:        cookie.Name,
		Value:       cookie.Value,
		Path:        cookie.Path,
		Domain:      cookie.Domain,
		Expires:     cookie.Expires,
		MaxAge:      cookie.MaxAge,
		Secure:      cookie.Secure,
		HttpOnly:    cookie.HTTPOnly,
		SameSite:    sameSite,
		Partitioned: cookie.Partitioned,
	}
	if std.Valid() != nil {
		// invalid cookies are ignored, same approach as net/http
		return
	}

	fc := fasthttp.AcquireCookie()
	defer fasthttp.ReleaseCookie(fc)

	fc.SetKey(std.Name)
	fc.SetValue(std.Value)
	fc.SetPath(std.Path)
	fc.SetDomain(std.Domain)
	if !cookie.SessionOnly {
		fc.SetMaxAge(std.MaxAge)
		fc.SetExpire(std.Expires)
	}
	fc.SetSecure(std.Secure)
	fc.SetHTTPOnly(std.HttpOnly)
	fc.SetSameSite(fastSameSite)
	fc.SetPartitioned(std.Partitioned)

	r.c.fasthttp.Response.Header.SetCookie(fc)
}
// Download transfers the file from path as an attachment.
// Typically, browsers will prompt the user for download.
// By default, the Content-Disposition header filename= parameter is the base name
// of file (this typically appears in the browser dialog).
// Override this default with the filename parameter.
func (r *DefaultRes) Download(file string, filename ...string) error {
	source := file
	if len(filename) > 0 {
		source = filename[0]
	}
	name := fallbackFilenameIfInvalid(sanitizeFilename(filepath.Base(source)))

	app := r.c.app
	ascii := app.isASCII(name)

	var quoted string
	if ascii {
		quoted = app.quoteString(name)
	} else {
		quoted = app.quoteRawString(name)
	}

	disposition := `attachment; filename="` + quoted + `"`
	if !ascii {
		// RFC 5987 extended parameter for non-ASCII names.
		disposition += `; filename*=UTF-8''` + url.PathEscape(name)
	}
	r.setCanonical(HeaderContentDisposition, disposition)

	return r.SendFile(file)
}
// Response returns the *fasthttp.Response object.
// This allows you to use all fasthttp response methods
// https://godoc.org/github.com/valyala/fasthttp#Response
func (r *DefaultRes) Response() *fasthttp.Response {
	ctx := r.c.fasthttp
	return &ctx.Response
}
// Format performs content-negotiation on the Accept HTTP header.
// It uses Accepts to select a proper format and calls the matching
// user-provided handler function.
// If no accepted format is found, and a format with MediaType "default" is given,
// that default handler is called. If no format is found and no default is given,
// StatusNotAcceptable is sent.
//
// When the request has no Accept header, the first handler with a concrete media
// type is used. "default" entries are skipped there because "default" is not a valid
// Content-Type. The default handler runs only if no concrete handler exists.
func (r *DefaultRes) Format(handlers ...ResFmt) error {
	if len(handlers) == 0 {
		return ErrNoHandlers
	}
	for i, h := range handlers {
		if h.Handler == nil {
			return fmt.Errorf("format handler is nil for media type %q at index %d", h.MediaType, i)
		}
	}

	r.Vary(HeaderAccept)

	// Using an int literal as the slice capacity allows for the slice to be
	// allocated on the stack. The number was chosen arbitrarily as an
	// approximation of the maximum number of content types a user might handle.
	// If the user goes over, it just causes allocations, so it's not a problem.
	types := make([]string, 0, 8)
	var defaultHandler Handler
	for _, h := range handlers {
		if h.MediaType == "default" {
			defaultHandler = h.Handler
			continue
		}
		types = append(types, h.MediaType)
	}

	if r.c.DefaultReq.Get(HeaderAccept) == "" {
		// No Accept header: any media type is acceptable, so serve the first
		// concrete one (this matches what Accepts returns for an empty header).
		for _, h := range handlers {
			if h.MediaType != "default" {
				r.c.fasthttp.Response.Header.SetContentType(h.MediaType)
				return h.Handler(r.c)
			}
		}
		// Only "default" handlers were supplied; the nil check above guarantees
		// defaultHandler is set.
		return defaultHandler(r.c)
	}

	accept := r.c.DefaultReq.Accepts(types...) //nolint:staticcheck // It is fine to ignore the static check
	if accept == "" {
		if defaultHandler == nil {
			return r.SendStatus(StatusNotAcceptable)
		}
		return defaultHandler(r.c)
	}

	for _, h := range handlers {
		if h.MediaType == accept {
			r.c.fasthttp.Response.Header.SetContentType(h.MediaType)
			return h.Handler(r.c)
		}
	}

	return fmt.Errorf("%w: format: an Accept was found but no handler was called", errUnreachable)
}
// AutoFormat performs content-negotiation on the Accept HTTP header.
// It uses Accepts to select a proper format.
// The supported content types are text/html, text/plain, application/json, application/xml, application/vnd.msgpack, and application/cbor.
// For more flexible content negotiation, use Format.
// If the header is not specified or there is no proper format, text/plain is used.
func (r *DefaultRes) AutoFormat(body any) error {
// Get accepted content type
accept := r.c.DefaultReq.Accepts("html", "json", "txt", "xml", "msgpack", "cbor") //nolint:staticcheck // It is fine to ignore the static check
// Set accepted content type
r.Type(accept)
// Type convert provided body
var b string
switch val := body.(type) {
case string:
b = val
case []byte:
b = r.c.app.toString(val)
default:
b = fmt.Sprintf("%v", val)
}
// Format based on the accept content type
switch accept {
case "txt":
return r.SendString(b)
case "json":
return r.JSON(body)
case "xml":
return r.XML(body)
case "html":
return r.SendString("
" + b + "
")
case "msgpack":
return r.MsgPack(body)
case "cbor":
return r.CBOR(body)
}
// Default case
return r.SendString(b)
}
// Get (a.k.a. GetRespHeader) returns the HTTP response header specified by field,
// or the first defaultValue when the header is empty.
// Field names are case-insensitive.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting instead.
func (r *DefaultRes) Get(key string, defaultValue ...string) string {
	raw := r.c.fasthttp.Response.Header.Peek(key)
	value := r.c.app.toString(raw)
	return defaultString(value, defaultValue)
}
// GetHeaders (a.k.a GetRespHeaders) returns the HTTP response headers, with
// repeated header names collected into one slice in the order they appear.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting instead.
func (r *DefaultRes) GetHeaders() map[string][]string {
	app := r.c.app
	header := &r.c.fasthttp.Response.Header

	// Sized up front from the header count to avoid map growth.
	out := make(map[string][]string, header.Len())
	for rawKey, rawValue := range header.All() {
		name := app.toString(rawKey)
		out[name] = append(out[name], app.toString(rawValue))
	}
	return out
}
// JSON converts any interface or string to JSON using the app's JSONEncoder.
// Array and slice values encode as JSON arrays,
// except that []byte encodes as a base64-encoded string,
// and a nil slice encodes as the null JSON value.
// If the ctype parameter is given, this method will set the
// Content-Type header equal to ctype. If ctype is not given,
// The Content-Type header will be set to application/json; charset=utf-8.
func (r *DefaultRes) JSON(data any, ctype ...string) error {
	encoded, err := r.c.app.config.JSONEncoder(data)
	if err != nil {
		return err
	}

	contentType := MIMEApplicationJSONCharsetUTF8
	if len(ctype) > 0 {
		contentType = ctype[0]
	}

	resp := &r.c.fasthttp.Response
	resp.SetBodyRaw(encoded)
	resp.Header.SetContentType(contentType)
	return nil
}
// MsgPack converts any interface or string to MessagePack encoded bytes using
// the app's MsgPackEncoder.
// If the ctype parameter is given, this method will set the
// Content-Type header equal to ctype. If ctype is not given,
// The Content-Type header will be set to application/vnd.msgpack.
func (r *DefaultRes) MsgPack(data any, ctype ...string) error {
	encoded, err := r.c.app.config.MsgPackEncoder(data)
	if err != nil {
		return err
	}

	contentType := MIMEApplicationMsgPack
	if len(ctype) > 0 {
		contentType = ctype[0]
	}

	resp := &r.c.fasthttp.Response
	resp.SetBodyRaw(encoded)
	resp.Header.SetContentType(contentType)
	return nil
}
// CBOR converts any interface or string to CBOR encoded bytes using the app's
// CBOREncoder.
// If the ctype parameter is given, this method will set the
// Content-Type header equal to ctype. If ctype is not given,
// The Content-Type header will be set to application/cbor.
func (r *DefaultRes) CBOR(data any, ctype ...string) error {
	encoded, err := r.c.app.config.CBOREncoder(data)
	if err != nil {
		return err
	}

	contentType := MIMEApplicationCBOR
	if len(ctype) > 0 {
		contentType = ctype[0]
	}

	resp := &r.c.fasthttp.Response
	resp.SetBodyRaw(encoded)
	resp.Header.SetContentType(contentType)
	return nil
}
// JSONP sends a JSON response with JSONP support.
// This method is identical to JSON, except that it opts-in to JSONP callback support.
// By default, the callback name is simply callback.
//
// The response body is `<callback>(<json>);` with Content-Type text/javascript and
// X-Content-Type-Options: nosniff. The callback name is written verbatim and is not
// validated, so validate it before passing request-derived input (e.g. a query parameter).
func (r *DefaultRes) JSONP(data any, callback ...string) error {
	encoded, err := r.c.app.config.JSONEncoder(data)
	if err != nil {
		return err
	}

	name := "callback"
	if len(callback) > 0 {
		name = callback[0]
	}

	// A pooled buffer avoids string concatenation allocations.
	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)
	buf.WriteString(name)
	buf.WriteByte('(')
	buf.Write(encoded)
	buf.WriteString(");")

	r.setCanonical(HeaderXContentTypeOptions, "nosniff")
	resp := &r.c.fasthttp.Response
	resp.Header.SetContentType(MIMETextJavaScriptCharsetUTF8)
	// SetBody copies, so the buffer can safely go back to the pool afterwards.
	resp.SetBody(buf.Bytes())
	return nil
}
// XML converts any interface or string to XML using the app's XMLEncoder.
// This method also sets the content header to application/xml; charset=utf-8.
func (r *DefaultRes) XML(data any) error {
	encoded, err := r.c.app.config.XMLEncoder(data)
	if err != nil {
		return err
	}

	resp := &r.c.fasthttp.Response
	resp.Header.SetContentType(MIMEApplicationXMLCharsetUTF8)
	resp.SetBodyRaw(encoded)
	return nil
}
// Links joins the links followed by the property to populate the response's Link HTTP header field.
// Arguments alternate between URL and rel value, e.g.
// Links("http://a", "next", "http://b", "last") produces
// `<http://a>; rel="next",<http://b>; rel="last"`.
func (r *DefaultRes) Links(link ...string) {
	if len(link) == 0 {
		return
	}

	bb := bytebufferpool.Get()
	for i, part := range link {
		if i%2 != 0 {
			// Odd positions are rel values.
			bb.WriteString(`; rel="`)
			bb.WriteString(part)
			bb.WriteString(`",`)
			continue
		}
		bb.WriteByte('<')
		bb.WriteString(part)
		bb.WriteByte('>')
	}

	value := utils.TrimRight(r.c.app.toString(bb.Bytes()), ',')
	r.setCanonical(HeaderLink, value)
	bytebufferpool.Put(bb)
}
// Location sets the response Location HTTP header to the specified path parameter.
// The value is stored as-is; no URL normalization is performed.
func (r *DefaultRes) Location(path string) {
	const header = HeaderLocation
	r.setCanonical(header, path)
}
// OriginalURL contains the original request URL; it delegates to the owning context.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting to use the value outside the Handler.
func (r *DefaultRes) OriginalURL() string {
	ctx := r.c
	return ctx.OriginalURL()
}
// Redirect returns the Redirect reference of the owning context.
// Use Redirect().Status() to set custom redirection status code.
// If status is not specified, status defaults to 303 See Other.
// You can use Redirect().To(), Redirect().Route() and Redirect().Back() for redirection.
func (r *DefaultRes) Redirect() *Redirect {
	ctx := r.c
	return ctx.Redirect()
}
// ViewBind adds vars to the default view variable map bound to the template engine.
// Variables are read by the Render method and may be overwritten.
func (r *DefaultRes) ViewBind(vars Map) error {
	ctx := r.c
	return ctx.ViewBind(vars)
}
// getLocationFromRoute builds the URL path for route by concatenating its
// constant segments and filling each parameter segment from params.
//
// A parameter segment is filled at most once. An exact key match wins. Otherwise
// the first key that matches case-insensitively (when the app is not
// CaseSensitive) or names a greedy wildcard is used. Parameters with no matching
// key are left empty. Returns ErrNotFound for a nil or unnamed (empty-path) route.
func (r *DefaultRes) getLocationFromRoute(route *Route, params Map) (string, error) {
	if route == nil || route.Path == "" {
		return "", ErrNotFound
	}

	app := r.c.app
	buf := bytebufferpool.Get()
	// Deferred so the buffer is returned to the pool on every path, including errors.
	defer bytebufferpool.Put(buf)

	for _, segment := range route.routeParser.segs {
		if !segment.IsParam {
			if _, err := buf.WriteString(segment.Const); err != nil {
				return "", fmt.Errorf("failed to write string: %w", err)
			}
			continue
		}

		val, ok := params[segment.ParamName]
		if !ok {
			for key, candidate := range params {
				isSame := !app.config.CaseSensitive && utils.EqualFold(key, segment.ParamName)
				isGreedy := segment.IsGreedy && len(key) == 1 && bytes.IndexByte(greedyParameters, key[0]) >= 0
				if isSame || isGreedy {
					val, ok = candidate, true
					break
				}
			}
		}
		if !ok {
			continue
		}

		if _, err := buf.WriteString(utils.ToString(val)); err != nil {
			return "", fmt.Errorf("failed to write string: %w", err)
		}
	}

	// String copies the bytes, so the result stays valid after the buffer is pooled.
	return buf.String(), nil
}
// GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831"
// An unknown route name yields ErrNotFound.
func (r *DefaultRes) GetRouteURL(routeName string, params Map) (string, error) {
	named := r.c.app.GetRoute(routeName)
	return r.getLocationFromRoute(&named, params)
}
// Render a template with data and sends a text/html response.
// We support the following engines: https://github.com/gofiber/template
//
// Mounted apps are checked from the last registered mount prefix to the first. The
// first app whose prefix matches the request URL and that has a Views engine
// renders the template. If no layouts are passed, a matching app's ViewsLayout
// is used. When no engine renders, name is read as a file path and executed
// with html/template.
//
// NOTE(review): prefixes are matched with strings.Contains rather than a
// prefix test, so "/v1" also matches "/docs/v1/..." — confirm this is intended.
func (r *DefaultRes) Render(name string, bind any, layouts ...string) error {
	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	// Initialize empty bind map if bind is nil
	if bind == nil {
		bind = make(Map)
	}

	// Pass-locals-to-views, bind, appListKeys
	r.c.renderExtensions(bind)

	rootApp := r.c.app
	requestURL := r.c.OriginalURL()
	keys := rootApp.mountFields.appListKeys

	rendered := false
	for i := len(keys) - 1; i >= 0 && !rendered; i-- {
		prefix := keys[i]
		if prefix != "" && !strings.Contains(requestURL, prefix) {
			continue
		}

		app := rootApp.mountFields.appList[prefix]
		if len(layouts) == 0 && app.config.ViewsLayout != "" {
			layouts = []string{app.config.ViewsLayout}
		}
		if app.config.Views == nil {
			continue
		}

		if err := app.config.Views.Render(buf, name, bind, layouts...); err != nil {
			return fmt.Errorf("failed to render: %w", err)
		}
		rendered = true
	}

	if !rendered {
		// No engine: treat name as a file path and render it as a raw template.
		if _, err := readContent(buf, name); err != nil {
			return err
		}

		tmpl, err := template.New("").Parse(rootApp.toString(buf.Bytes()))
		if err != nil {
			return fmt.Errorf("failed to parse: %w", err)
		}

		buf.Reset()
		if err := tmpl.Execute(buf, bind); err != nil {
			return fmt.Errorf("failed to execute: %w", err)
		}
	}

	resp := &r.c.fasthttp.Response
	resp.Header.SetContentType(MIMETextHTMLCharsetUTF8)
	// SetBody copies, so the pooled buffer can be released by the deferred Put.
	resp.SetBody(buf.Bytes())
	return nil
}
// renderExtensions forwards bind to the owning context's renderExtensions.
func (r *DefaultRes) renderExtensions(bind any) {
	ctx := r.c
	ctx.renderExtensions(bind)
}
// Send sets the HTTP response body without copying it.
// From this point onward the body argument must not be changed.
func (r *DefaultRes) Send(body []byte) error {
	resp := &r.c.fasthttp.Response
	resp.SetBodyRaw(body)
	return nil
}
// SendEarlyHints allows the server to hint to the browser what resources a page would need
// so the browser can preload them while waiting for the server's full response.
// Each hint is added to the response as a Link header (they are not removed
// afterwards), and then an Early Hints response is sent.
// An empty hints slice is a no-op.
//
// This is a HTTP/2+ feature but all browsers will either understand it or safely ignore it.
//
// NOTE: Older HTTP/1.1 non-browser clients may face compatibility issues.
//
// See: https://developer.chrome.com/docs/web-platform/early-hints and
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Link#syntax
func (r *DefaultRes) SendEarlyHints(hints []string) error {
	if len(hints) == 0 {
		return nil
	}

	header := &r.c.fasthttp.Response.Header
	for i := range hints {
		header.Add("Link", hints[i])
	}
	return r.c.fasthttp.EarlyHints()
}
// SendFile transfers the file from the specified path.
// By default, the file is not compressed. To enable compression, set SendFile.Compress to true.
// The Content-Type response HTTP header field is set based on the file's extension.
// If the file extension is missing or invalid, the Content-Type is detected from the file's format.
//
// fasthttp file handlers are cached on the App, one per distinct SendFile
// configuration, and reused by later calls with an equal configuration.
// A missing file yields a *fiber.Error with StatusNotFound.
func (r *DefaultRes) SendFile(file string, config ...SendFile) error {
	// Save the filename, we will need it in the error message if the file isn't found
	filename := file

	var cfg SendFile
	if len(config) > 0 {
		cfg = config[0]
	}

	// Zero selects the default; negative values disable caching (see SkipCache below).
	if cfg.CacheDuration == 0 {
		cfg.CacheDuration = 10 * time.Second
	}

	var fsHandler fasthttp.RequestHandler
	var cacheControlValue string

	app := r.c.app
	// Fast path: reuse a cached handler under the read lock.
	app.sendfilesMutex.RLock()
	for _, sf := range app.sendfiles {
		if sf.configEqual(cfg) {
			fsHandler = sf.handler
			cacheControlValue = sf.cacheControlValue
			break
		}
	}
	app.sendfilesMutex.RUnlock()

	if fsHandler == nil {
		// Slow path: re-check under the write lock before creating a handler, so
		// concurrent first calls with the same config register a single entry
		// instead of each appending its own duplicate handler.
		sf := func() *sendFileStore {
			app.sendfilesMutex.Lock()
			defer app.sendfilesMutex.Unlock()

			for _, existing := range app.sendfiles {
				if existing.configEqual(cfg) {
					return existing
				}
			}

			fasthttpFS := &fasthttp.FS{
				Root:                   "",
				FS:                     cfg.FS,
				AllowEmptyRoot:         true,
				GenerateIndexPages:     false,
				AcceptByteRange:        cfg.ByteRange,
				Compress:               cfg.Compress,
				CompressBrotli:         cfg.Compress,
				CompressZstd:           cfg.Compress,
				CompressedFileSuffixes: app.config.CompressedFileSuffixes,
				CacheDuration:          cfg.CacheDuration,
				SkipCache:              cfg.CacheDuration < 0,
				IndexNames:             []string{"index.html"},
				PathNotFound: func(ctx *fasthttp.RequestCtx) {
					ctx.Response.SetStatusCode(StatusNotFound)
				},
			}
			if cfg.FS != nil {
				fasthttpFS.Root = "."
			}

			created := &sendFileStore{
				config:  cfg,
				handler: fasthttpFS.NewRequestHandler(),
			}
			if maxAge := cfg.MaxAge; maxAge > 0 {
				created.cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge)
			}

			app.sendfiles = append(app.sendfiles, created)
			return created
		}()

		fsHandler = sf.handler
		cacheControlValue = sf.cacheControlValue
	}

	// Keep original path for mutable params
	r.c.pathOriginal = utils.CopyString(r.c.pathOriginal)

	request := &r.c.fasthttp.Request
	// Delete the Accept-Encoding header if compression is disabled
	if !cfg.Compress {
		// https://github.com/valyala/fasthttp/blob/7cc6f4c513f9e0d3686142e0a1a5aa2f76b3194a/fs.go#L55
		request.Header.Del(HeaderAcceptEncoding)
	}

	// copy of https://github.com/valyala/fasthttp/blob/7cc6f4c513f9e0d3686142e0a1a5aa2f76b3194a/fs.go#L103-L121 with small adjustments
	if file == "" || (!filepath.IsAbs(file) && cfg.FS == nil) {
		// extend relative path to absolute path
		hasTrailingSlash := file != "" && (file[len(file)-1] == '/' || file[len(file)-1] == '\\')

		var err error
		file = filepath.FromSlash(file)
		if file, err = filepath.Abs(file); err != nil {
			return fmt.Errorf("failed to determine abs file path: %w", err)
		}
		if hasTrailingSlash {
			file += "/"
		}
	}

	// convert the path to forward slashes regardless the OS in order to set the URI properly
	// the handler will convert back to OS path separator before opening the file
	file = filepath.ToSlash(file)

	// Restore the original requested URL
	originalURL := utils.CopyString(r.c.OriginalURL())
	defer request.SetRequestURI(originalURL)

	// Set new URI for fileHandler
	request.SetRequestURI(file)

	// For range requests, remember the file size so an unsatisfiable range can
	// be answered with a "bytes */<size>" Content-Range header.
	var (
		sendFileSize    int64
		hasSendFileSize bool
	)
	if cfg.ByteRange && len(request.Header.Peek(HeaderRange)) > 0 {
		sizePath := file
		if cfg.FS != nil {
			sizePath = filepath.ToSlash(filename)
		}
		if size, err := sendFileContentLength(sizePath, cfg); err == nil {
			sendFileSize = size
			hasSendFileSize = true
		}
	}

	// Save status code
	response := &r.c.fasthttp.Response
	status := response.StatusCode()

	// Serve file
	fsHandler(r.c.fasthttp)

	// Sets the response Content-Disposition header to attachment if the Download option is true
	if cfg.Download {
		r.Attachment()
	}

	// Get the status code which is set by fasthttp
	fsStatus := response.StatusCode()

	// Check for error
	if status != StatusNotFound && fsStatus == StatusNotFound {
		return NewError(StatusNotFound, fmt.Sprintf("sendfile: file %s not found", filename))
	}

	// Set the status code set by the user if it is different from the fasthttp status code and 200
	if status != fsStatus && status != StatusOK {
		r.Status(status)
	}

	// Apply cache control header
	if status != StatusNotFound && status != StatusForbidden {
		if cfg.ByteRange && hasSendFileSize && response.StatusCode() == StatusRequestedRangeNotSatisfiable && len(response.Header.Peek(HeaderContentRange)) == 0 {
			response.Header.Set(HeaderContentRange, "bytes */"+strconv.FormatInt(sendFileSize, 10))
		}
		if cacheControlValue != "" {
			response.Header.Set(HeaderCacheControl, cacheControlValue)
		}
	}
	return nil
}
// sendFileContentLength returns the size in bytes of the file at path.
// With cfg.FS set, path is treated as slash-separated and relative to the FS root
// (leading slashes are stripped and the path is cleaned). Otherwise it is stat'ed
// on the local file system after converting slashes to the OS separator.
func sendFileContentLength(path string, cfg SendFile) (int64, error) {
	if cfg.FS == nil {
		info, err := os.Stat(filepath.FromSlash(path))
		if err != nil {
			return 0, fmt.Errorf("stat %q: %w", path, err)
		}
		return info.Size(), nil
	}

	name := pathpkg.Clean(utils.TrimLeft(path, '/'))
	if name == "." {
		name = ""
	}
	info, err := fs.Stat(cfg.FS, name)
	if err != nil {
		return 0, fmt.Errorf("stat %q: %w", name, err)
	}
	return info.Size(), nil
}
// SendStatus sets the HTTP status code and if the response body is empty,
// it sets the correct status message in the body.
// For statuses that must not carry a body (1xx, 204, 205, 304) the body is cleared.
func (r *DefaultRes) SendStatus(status int) error {
	r.Status(status)

	resp := &r.c.fasthttp.Response
	switch {
	case statusDisallowsBody(status):
		resp.ResetBody()
	case len(resp.Body()) == 0:
		// Only set status body when there is no response body
		return r.SendString(utils.StatusMessage(status))
	}
	return nil
}
// SendString sets the HTTP response body for string types.
// This means no type assertion, recommended for faster performance.
func (r *DefaultRes) SendString(body string) error {
	resp := &r.c.fasthttp.Response
	resp.SetBodyString(body)
	return nil
}
// SendStream sets response body stream and optional body size.
// A missing or negative size is passed to fasthttp as -1 (unknown size).
func (r *DefaultRes) SendStream(stream io.Reader, size ...int) error {
	bodySize := -1
	if len(size) > 0 && size[0] >= 0 {
		bodySize = size[0]
	}
	r.c.fasthttp.Response.SetBodyStream(stream, bodySize)
	return nil
}
// SendStreamWriter sets the response body stream writer; streamWriter is
// invoked with a buffered writer to produce the body.
func (r *DefaultRes) SendStreamWriter(streamWriter func(*bufio.Writer)) error {
	writer := fasthttp.StreamWriter(streamWriter)
	r.c.fasthttp.Response.SetBodyStreamWriter(writer)
	return nil
}
// Set sets the response's HTTP header field to the specified key, value,
// replacing any existing value.
func (r *DefaultRes) Set(key, val string) {
	header := &r.c.fasthttp.Response.Header
	header.Set(key, val)
}
// setCanonical sets a response header whose key is already in canonical form,
// passing zero-copy byte views of key and val to fasthttp's SetCanonical.
func (r *DefaultRes) setCanonical(key, val string) {
	k, v := utils.UnsafeBytes(key), utils.UnsafeBytes(val)
	r.c.fasthttp.Response.Header.SetCanonical(k, v)
}
// Status sets the HTTP status for the response.
// This method is chainable: it returns the owning context.
func (r *DefaultRes) Status(status int) Ctx {
	ctx := r.c
	ctx.fasthttp.Response.SetStatusCode(status)
	return ctx
}
// statusDisallowsBody reports whether responses with this status code must not
// carry a body: all 1xx (Informational) codes per RFC 9110, plus 204 No Content,
// 205 Reset Content and 304 Not Modified.
func statusDisallowsBody(status int) bool {
	switch {
	case status >= 100 && status < 200:
		return true
	case status == StatusNoContent, status == StatusResetContent, status == StatusNotModified:
		return true
	}
	return false
}
// Type sets the Content-Type HTTP header to the MIME type specified by the file extension.
// An explicit charset is appended as "; charset=<charset>". Without one, text-based
// MIME types (see shouldIncludeCharset) get "; charset=utf-8".
// It returns the owning context for chaining.
func (r *DefaultRes) Type(extension string, charset ...string) Ctx {
	contentType := utils.GetMIME(extension)
	switch {
	case len(charset) > 0:
		contentType += "; charset=" + charset[0]
	case shouldIncludeCharset(contentType):
		contentType += "; charset=utf-8"
	}
	r.c.fasthttp.Response.Header.SetContentType(contentType)
	return r.c
}
// shouldIncludeCharset determines if a MIME type should include UTF-8 charset by default.
// That is every text/* type, application/json, application/javascript,
// application/xml, and any structured-syntax type ending in +json or +xml.
func shouldIncludeCharset(mimeType string) bool {
	switch {
	case strings.HasPrefix(mimeType, "text/"):
		return true
	case mimeType == MIMEApplicationJSON,
		mimeType == MIMEApplicationJavaScript,
		mimeType == MIMEApplicationXML:
		return true
	default:
		return strings.HasSuffix(mimeType, "+json") || strings.HasSuffix(mimeType, "+xml")
	}
}
// Vary adds the given header fields to the Vary response header.
// Fields already listed are left in place; new ones are appended.
func (r *DefaultRes) Vary(fields ...string) {
	// Append skips values already present in the header, which gives Vary its
	// "list each field once" behavior.
	r.Append(HeaderVary, fields...)
}
// Write appends p into response body. It always reports len(p) bytes written
// and a nil error, satisfying io.Writer.
func (r *DefaultRes) Write(p []byte) (int, error) {
	n := len(p)
	r.c.fasthttp.Response.AppendBody(p)
	return n, nil
}
// Writef formats a according to f (fmt.Fprintf semantics) and appends the
// result to the response body writer.
func (r *DefaultRes) Writef(f string, a ...any) (int, error) {
	w := r.c.fasthttp.Response.BodyWriter()
	return fmt.Fprintf(w, f, a...) //nolint:wrapcheck // This must not be wrapped
}
// WriteString appends s to response body. It always reports len(s) bytes
// written and a nil error, satisfying io.StringWriter.
func (r *DefaultRes) WriteString(s string) (int, error) {
	n := len(s)
	r.c.fasthttp.Response.AppendBodyString(s)
	return n, nil
}
// release resets the Res fields (dropping the context reference) when the
// owning context is released via ReleaseCtx().
func (r *DefaultRes) release() {
	*r = DefaultRes{}
}
// Drop closes the underlying connection without sending any response headers or body.
// This can be useful for silently terminating client connections, such as in DDoS mitigation
// or when blocking access to sensitive endpoints.
//
// If there is no request context, or ctx.Conn() returns nil (e.g. in streaming
// mode), there is no connection to close. Drop then returns nil instead of
// panicking on a nil net.Conn, the same guards End uses.
func (r *DefaultRes) Drop() error {
	ctx := r.c.fasthttp
	if ctx == nil {
		return nil
	}

	conn := ctx.Conn()
	if conn == nil {
		return nil
	}

	//nolint:wrapcheck // error wrapping is avoided to keep the operation lightweight and focused on connection closure.
	return conn.Close()
}
// End immediately flushes the current response and closes the underlying connection.
// The response is serialized through a buffered writer onto the raw connection, the
// writer is flushed, and the connection is closed. It returns nil without doing
// anything when there is no request context or no connection.
//
// Note: End does not work when using streaming (e.g. fasthttp's HijackConn or SendStream),
// because in streaming mode the connection is managed asynchronously and ctx.Conn() may return nil.
func (r *DefaultRes) End() error {
	ctx := r.c.fasthttp
	if ctx == nil {
		return nil
	}

	conn := ctx.Conn()
	if conn == nil {
		return nil
	}

	w := bufio.NewWriter(conn)
	err := ctx.Response.Write(w)
	if err == nil {
		err = w.Flush()
	}
	if err != nil {
		return err //nolint:wrapcheck // unnecessary to wrap it
	}

	return conn.Close() //nolint:wrapcheck // unnecessary to wrap it
}
================================================
FILE: res_interface_gen.go
================================================
// Code generated by ifacemaker; DO NOT EDIT.
package fiber
import (
"bufio"
"io"
"github.com/valyala/fasthttp"
)
// Res is an interface for response-related Ctx methods.
//
// NOTE(review): this file is generated by ifacemaker ("DO NOT EDIT"); comment
// fixes made here must also be applied to the source method comments or they
// will be lost on regeneration.
type Res interface {
	// App returns the *App reference to the instance of the Fiber application.
	App() *App
	// Append appends the specified value(s) to the HTTP response header field.
	// If the header is not already set, it creates the header with the specified value.
	Append(field string, values ...string)
	// Attachment sets the HTTP response Content-Disposition header field to attachment.
	Attachment(filename ...string)
	// ClearCookie expires a specific cookie by key on the client side.
	// If no key is provided it expires all cookies that came with the request.
	ClearCookie(key ...string)
	// RequestCtx returns *fasthttp.RequestCtx that carries a deadline,
	// a cancellation signal, and other values across API boundaries.
	RequestCtx() *fasthttp.RequestCtx
	// Cookie sets a cookie by passing a cookie struct.
	Cookie(cookie *Cookie)
	// Download transfers the file from path as an attachment.
	// Typically, browsers will prompt the user for download.
	// By default, the Content-Disposition header filename= parameter is the filepath (this typically appears in the browser dialog).
	// Override this default with the filename parameter.
	Download(file string, filename ...string) error
	// Response returns the *fasthttp.Response object.
	// This allows you to use all fasthttp response methods.
	// https://godoc.org/github.com/valyala/fasthttp#Response
	Response() *fasthttp.Response
	// Format performs content-negotiation on the Accept HTTP header.
	// It uses Accepts to select a proper format and calls the matching
	// user-provided handler function.
	// If no accepted format is found, and a format with MediaType "default" is given,
	// that default handler is called. If no format is found and no default is given,
	// StatusNotAcceptable is sent.
	Format(handlers ...ResFmt) error
	// AutoFormat performs content-negotiation on the Accept HTTP header.
	// It uses Accepts to select a proper format.
	// The supported content types are text/html, text/plain, application/json, application/xml, application/vnd.msgpack, and application/cbor.
	// For more flexible content negotiation, use Format.
	// If the header is not specified or there is no proper format, text/plain is used.
	AutoFormat(body any) error
	// Get (a.k.a. GetRespHeader) returns the HTTP response header specified by field.
	// Field names are case-insensitive.
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting instead.
	Get(key string, defaultValue ...string) string
	// GetHeaders (a.k.a. GetRespHeaders) returns the HTTP response headers.
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting instead.
	GetHeaders() map[string][]string
	// JSON converts any interface or string to JSON.
	// Array and slice values encode as JSON arrays,
	// except that []byte encodes as a base64-encoded string,
	// and a nil slice encodes as the null JSON value.
	// If the ctype parameter is given, this method will set the
	// Content-Type header equal to ctype. If ctype is not given,
	// the Content-Type header will be set to application/json; charset=utf-8.
	JSON(data any, ctype ...string) error
	// MsgPack converts any interface or string to MessagePack encoded bytes.
	// If the ctype parameter is given, this method will set the
	// Content-Type header equal to ctype. If ctype is not given,
	// the Content-Type header will be set to application/vnd.msgpack.
	MsgPack(data any, ctype ...string) error
	// CBOR converts any interface or string to CBOR encoded bytes.
	// If the ctype parameter is given, this method will set the
	// Content-Type header equal to ctype. If ctype is not given,
	// the Content-Type header will be set to application/cbor.
	CBOR(data any, ctype ...string) error
	// JSONP sends a JSON response with JSONP support.
	// This method is identical to JSON, except that it opts-in to JSONP callback support.
	// By default, the callback name is simply callback.
	JSONP(data any, callback ...string) error
	// XML converts any interface or string to XML.
	// This method also sets the content header to application/xml; charset=utf-8.
	XML(data any) error
	// Links joins the links followed by the property to populate the response's Link HTTP header field.
	Links(link ...string)
	// Location sets the response Location HTTP header to the specified path parameter.
	Location(path string)
	// OriginalURL contains the original request URL.
	// Returned value is only valid within the handler. Do not store any references.
	// Make copies or use the Immutable setting to use the value outside the Handler.
	OriginalURL() string
	// Redirect returns the Redirect reference.
	// Use Redirect().Status() to set custom redirection status code.
	// If status is not specified, status defaults to 303 See Other.
	// You can use Redirect().To(), Redirect().Route() and Redirect().Back() for redirection.
	Redirect() *Redirect
	// ViewBind adds vars to the default view var map that is bound to the template engine.
	// Variables are read by the Render method and may be overwritten.
	ViewBind(vars Map) error
	// getLocationFromRoute builds the URL location for route, filling in params.
	getLocationFromRoute(route *Route, params Map) (string, error)
	// GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831"
	GetRouteURL(routeName string, params Map) (string, error)
	// Render renders a template with data and sends a text/html response.
	// We support the following engines: https://github.com/gofiber/template
	Render(name string, bind any, layouts ...string) error
	// renderExtensions is an internal rendering helper.
	// NOTE(review): presumably merges view-bound variables into bind before
	// rendering — confirm against the implementation.
	renderExtensions(bind any)
	// Send sets the HTTP response body without copying it.
	// From this point onward the body argument must not be changed.
	Send(body []byte) error
	// SendEarlyHints allows the server to hint to the browser what resources a page would need
	// so the browser can preload them while waiting for the server's full response. Only Link
	// headers already written to the response will be transmitted as Early Hints.
	//
	// This is a HTTP/2+ feature but all browsers will either understand it or safely ignore it.
	//
	// NOTE: Older HTTP/1.1 non-browser clients may face compatibility issues.
	//
	// See: https://developer.chrome.com/docs/web-platform/early-hints and
	// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Link#syntax
	SendEarlyHints(hints []string) error
	// SendFile transfers the file from the specified path.
	// By default, the file is not compressed. To enable compression, set SendFile.Compress to true.
	// The Content-Type response HTTP header field is set based on the file's extension.
	// If the file extension is missing or invalid, the Content-Type is detected from the file's format.
	SendFile(file string, config ...SendFile) error
	// SendStatus sets the HTTP status code and if the response body is empty,
	// it sets the correct status message in the body.
	SendStatus(status int) error
	// SendString sets the HTTP response body for string types.
	// This means no type assertion, recommended for faster performance.
	SendString(body string) error
	// SendStream sets response body stream and optional body size.
	SendStream(stream io.Reader, size ...int) error
	// SendStreamWriter sets the response body stream writer.
	SendStreamWriter(streamWriter func(*bufio.Writer)) error
	// Set sets the response's HTTP header field to the specified key, value.
	Set(key, val string)
	// setCanonical is an internal variant of Set.
	// NOTE(review): presumably expects key to already be in canonical header
	// form — confirm against the implementation.
	setCanonical(key, val string)
	// Status sets the HTTP status for the response.
	// This method is chainable.
	Status(status int) Ctx
	// Type sets the Content-Type HTTP header to the MIME type specified by the file extension.
	Type(extension string, charset ...string) Ctx
	// Vary adds the given header field to the Vary response header.
	// This will append the header, if not already listed; otherwise, leaves it listed in the current location.
	Vary(fields ...string)
	// Write appends p into response body.
	Write(p []byte) (int, error)
	// Writef appends f & a into response body writer.
	Writef(f string, a ...any) (int, error)
	// WriteString appends s to response body.
	WriteString(s string) (int, error)
	// release resets the Res fields; it is used when the context is released via ReleaseCtx().
	release()
	// Drop closes the underlying connection without sending any response headers or body.
	// This can be useful for silently terminating client connections, such as in DDoS mitigation
	// or when blocking access to sensitive endpoints.
	Drop() error
	// End immediately flushes the current response and closes the underlying connection.
	// It does not work in streaming mode, where the connection may be unavailable.
	End() error
}
================================================
FILE: router.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"fmt"
"slices"
"sync/atomic"
"github.com/gofiber/utils/v2"
utilsstrings "github.com/gofiber/utils/v2/strings"
"github.com/valyala/fasthttp"
)
// Router is the route-registration interface shared by the app router and the
// group router. Every method returns a value (a Router, or a Register for
// RouteChain) so registration calls can be chained.
type Router interface {
	// Use registers middleware. The accepted argument forms are defined by
	// the implementation.
	Use(args ...any) Router

	// Add registers a handler chain on path for each method in methods.
	// All is the all-methods counterpart.
	Add(methods []string, path string, handler any, handlers ...any) Router
	All(path string, handler any, handlers ...any) Router

	// Per-HTTP-method shortcuts. Each takes a path, one required handler and
	// any number of additional handlers.
	Get(path string, handler any, handlers ...any) Router
	Head(path string, handler any, handlers ...any) Router
	Post(path string, handler any, handlers ...any) Router
	Put(path string, handler any, handlers ...any) Router
	Patch(path string, handler any, handlers ...any) Router
	Delete(path string, handler any, handlers ...any) Router
	Options(path string, handler any, handlers ...any) Router
	Connect(path string, handler any, handlers ...any) Router
	Trace(path string, handler any, handlers ...any) Router

	// Nested routing helpers rooted at the given prefix or path.
	Group(prefix string, handlers ...any) Router
	Route(prefix string, fn func(router Router), name ...string) Router
	RouteChain(path string) Register

	// Name assigns a name to a route; which route is named is decided by the
	// implementation.
	Name(name string) Router
}
// Route is a struct that holds all metadata for each registered handler.
//
// The exported fields carry JSON tags (Handlers is excluded with "-"); the
// unexported fields hold the precomputed data that Route.match and the router
// use at request time. Field order matters for JSON output, so do not reorder.
type Route struct {
	// ### important: always keep in sync with the copy method "app.copyRoute" and all creations of Route struct ###
	// NOTE: app.copyRoute currently does not copy group; callers that need it set it explicitly.
	group *Group // Group instance; used for routes in groups
	path  string // Normalized path: lowercased unless CaseSensitive, trailing '/' trimmed unless StrictRouting, escape chars removed
	// Public fields
	Method string `json:"method"` // HTTP method
	Name   string `json:"name"`   // Route's name
	//nolint:revive // Having both a Path (uppercase) and a path (lowercase) is fine
	Path        string      `json:"path"`   // Original registered route path
	Params      []string    `json:"params"` // Case-sensitive param keys (parsed from the raw path)
	Handlers    []Handler   `json:"-"`      // Ctx handlers
	routeParser routeParser // Parameter parser built from the normalized path
	// Data for routing
	use      bool // Middleware (USE) route; matches path prefixes
	mount    bool // Indicates a mounted app on a specific route
	star     bool // Normalized path equals "/*"
	root     bool // Normalized path equals "/"
	autoHead bool // HEAD route generated automatically from a GET route
}
// match reports whether the route accepts the request. detectionPath is the
// normalized request path used for comparisons, and path is the original
// request path used for captured values. Any captured parameter values are
// written into params.
//
// Matching rules, in order:
//   - a root route matches a detectionPath of exactly "/";
//   - a star route matches anything and captures path without its leading '/';
//   - a route with parameters defers to its precomputed routeParser;
//   - a middleware (use) route matches by prefix on a segment boundary
//     (a root middleware matches every path starting with '/');
//   - any other route requires detectionPath to equal r.path exactly.
func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool {
	if r.root && detectionPath == "/" {
		return true
	}

	if r.star {
		params[0] = ""
		if len(path) > 1 {
			params[0] = path[1:]
		}
		return true
	}

	if len(r.Params) > 0 && r.routeParser.getMatch(detectionPath, path, params, r.use) {
		return true
	}

	if !r.use {
		// Plain routes need an exact match of the normalized path.
		return detectionPath == r.path
	}

	if r.root {
		// Middleware mounted at "/" accepts every path that starts with '/'.
		return detectionPath != "" && detectionPath[0] == '/'
	}

	// Other middleware must be a prefix of detectionPath ending on a boundary.
	prefixLen := len(r.path)
	return len(detectionPath) >= prefixLen &&
		detectionPath[:prefixLen] == r.path &&
		hasPartialMatchBoundary(detectionPath, prefixLen)
}
// next resumes routing for c. It walks the route tree of the request method,
// starting just after c.indexRoute, and runs the first handler of the first
// route that matches.
//
// Results:
//   - (true, err): a route matched; err is what its first handler returned
//     (nil if the route has no handlers).
//   - (false, ErrNotFound): nothing matched, or a non-middleware route already
//     matched earlier in this chain (c.matched).
//   - (false, ErrMethodNotAllowed): the path matches a non-middleware route of
//     another method; each such method is appended to the Allow header.
func (app *App) next(c *DefaultCtx) (bool, error) {
	methodIdx := c.methodInt
	hash := c.treePathHash
	// Neither path changes while scanning, so convert them once.
	detection := utils.UnsafeString(c.detectionPath)
	reqPath := utils.UnsafeString(c.path)

	// Prefer the bucket for the request's path prefix; fall back to bucket 0.
	routes, found := app.treeStack[methodIdx][hash]
	if !found {
		routes = app.treeStack[methodIdx][0]
	}

	for idx := c.indexRoute + 1; idx < len(routes); idx++ {
		route := routes[idx]
		if route.mount || !route.match(detection, reqPath, &c.values) {
			continue
		}
		if c.skipNonUseRoutes && !route.use {
			continue
		}

		c.route = route
		if !route.use {
			c.matched = true
		}
		if len(route.Handlers) == 0 {
			return true, nil // Stop scanning the stack
		}
		c.indexHandler = 0
		c.indexRoute = idx
		return true, route.Handlers[0](c)
	}

	// A real (non-middleware) route already ran earlier in the chain, so
	// running out of routes here is a plain 404.
	if c.matched {
		return false, ErrNotFound
	}

	// Nothing matched for this method: look for the path under other methods
	// so the caller can answer 405 with an Allow header.
	allowed := false
	methods := app.config.RequestMethods
	for other := range methods {
		if other == methodIdx {
			continue
		}
		candidates, hit := app.treeStack[other][hash]
		if !hit {
			candidates = app.treeStack[other][0]
		}
		// c.indexRoute ends at the matching index, or at the last index when
		// no candidate matched.
		stop := len(candidates) - 1
		for pos, cand := range candidates {
			if !cand.use && cand.match(detection, reqPath, &c.values) {
				allowed = true
				c.Append(HeaderAllow, methods[other])
				stop = pos
				break
			}
		}
		c.indexRoute = stop
	}

	if allowed {
		return false, ErrMethodNotAllowed
	}
	return false, ErrNotFound
}
// nextCustom is the CustomCtx counterpart of next. It has the same matching
// rules and results, but reads and writes routing state through the
// CustomCtx accessor methods instead of DefaultCtx fields. The accessors are
// called at the same points as before, so custom implementations observe the
// same call sequence.
func (app *App) nextCustom(c CustomCtx) (bool, error) {
	methodIdx := c.getMethodInt()
	hash := c.getTreePathHash()

	// Prefer the bucket for the request's path prefix; fall back to bucket 0.
	routes, found := app.treeStack[methodIdx][hash]
	if !found {
		routes = app.treeStack[methodIdx][0]
	}

	for idx := c.getIndexRoute() + 1; idx < len(routes); idx++ {
		route := routes[idx]
		if route.mount || !route.match(c.getDetectionPath(), c.Path(), c.getValues()) {
			continue
		}
		if c.getSkipNonUseRoutes() && !route.use {
			continue
		}

		c.setRoute(route)
		if !route.use {
			c.setMatched(true)
		}
		if len(route.Handlers) == 0 {
			return true, nil // Stop scanning the stack
		}
		c.setIndexHandler(0)
		c.setIndexRoute(idx)
		return true, route.Handlers[0](c)
	}

	// A real (non-middleware) route already ran earlier in the chain.
	if c.getMatched() {
		return false, ErrNotFound
	}

	// Look for the path under other methods to decide between 404 and 405.
	allowed := false
	methods := app.config.RequestMethods
	for other := range methods {
		if other == methodIdx {
			continue
		}
		candidates, hit := app.treeStack[other][hash]
		if !hit {
			candidates = app.treeStack[other][0]
		}
		// The index route ends at the matching index, or at the last index
		// when no candidate matched.
		stop := len(candidates) - 1
		for pos, cand := range candidates {
			if !cand.use && cand.match(c.getDetectionPath(), c.Path(), c.getValues()) {
				allowed = true
				c.Append(HeaderAllow, methods[other])
				stop = pos
				break
			}
		}
		c.setIndexRoute(stop)
	}

	if allowed {
		return false, ErrMethodNotAllowed
	}
	return false, ErrNotFound
}
// requestHandler is the fasthttp entry point. It acquires a pooled context and
// answers 501 for unknown HTTP methods. It consumes pending flash messages,
// then routes the request via next (DefaultCtx) or nextCustom (any other
// CustomCtx). A routing error goes to the app's ErrorHandler. If that handler
// itself fails, a bare 500 is sent.
func (app *App) requestHandler(rctx *fasthttp.RequestCtx) {
	ctx := app.AcquireCtx(rctx)
	defer app.ReleaseCtx(ctx)

	var err error
	switch typed := ctx.(type) {
	case *DefaultCtx:
		if typed.methodInt == -1 {
			_ = typed.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil
			return
		}
		// Optional: check flash messages (hot path, see hasFlashCookie).
		if hasFlashCookie(&typed.Request().Header) {
			typed.Redirect().parseAndClearFlashMessages()
		}
		_, err = app.next(typed)
	default:
		if ctx.getMethodInt() == -1 {
			_ = ctx.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil
			return
		}
		// Optional: check flash messages (hot path, see hasFlashCookie).
		if hasFlashCookie(&ctx.Request().Header) {
			ctx.Redirect().parseAndClearFlashMessages()
		}
		_, err = app.nextCustom(ctx)
	}

	if err == nil {
		return
	}
	if handlerErr := ctx.App().ErrorHandler(ctx, err); handlerErr != nil {
		_ = ctx.SendStatus(StatusInternalServerError) //nolint:errcheck // Always return nil
	}
}
// addPrefixToRoute rewrites route in place so that it lives under prefix and
// returns the same pointer.
//
// It re-derives all path-dependent fields from the prefixed path, using the
// rules that register applies:
//   - Path: the raw prefixed path;
//   - path: lowercased unless CaseSensitive, trailing '/' trimmed unless
//     StrictRouting, escape characters removed;
//   - routeParser: parsed from the normalized path;
//   - Params: parsed from the raw prefixed path.
//
// Params must be recomputed alongside routeParser. Route.match only consults
// the parser when len(Params) > 0, and parameter values are looked up by
// position in Params. A stale Params would therefore break matching and
// misalign values whenever the prefix itself contains parameters.
// The root and star shortcut flags are cleared.
func (app *App) addPrefixToRoute(prefix string, route *Route) *Route {
	prefixedPath := getGroupPath(prefix, route.Path)
	prettyPath := prefixedPath
	// Case-insensitive routing: compare against the lowercased path.
	if !app.config.CaseSensitive {
		prettyPath = utilsstrings.ToLower(prettyPath)
	}
	// Non-strict routing: ignore trailing slashes.
	if !app.config.StrictRouting && len(prettyPath) > 1 {
		prettyPath = utils.TrimRight(prettyPath, '/')
	}

	route.Path = prefixedPath
	route.path = RemoveEscapeChar(prettyPath)
	route.routeParser = parseRoute(prettyPath, app.customConstraints...)
	// Mirror register: parameter keys come from the raw (case-preserving) path.
	route.Params = parseRoute(prefixedPath, app.customConstraints...).params
	route.root = false
	route.star = false

	return route
}
// copyRoute returns a new *Route populated from route. The copy is shallow:
// the Params and Handlers slices are shared with the source, and routeParser
// is copied by value, so its internal slices are shared too. The group field
// is not copied and is left nil.
func (*App) copyRoute(route *Route) *Route {
	return &Route{
		// Public data
		Method:   route.Method,
		Name:     route.Name,
		Path:     route.Path,
		Params:   route.Params,
		Handlers: route.Handlers,

		// Path data
		path:        route.path,
		routeParser: route.routeParser,

		// Router booleans
		use:      route.use,
		mount:    route.mount,
		star:     route.star,
		root:     route.root,
		autoHead: route.autoHead,
	}
}
// normalizePath converts a user-supplied path into the normalized form stored
// in Route.path. An empty path becomes "/" and a missing leading '/' is added.
// The path is lowercased unless CaseSensitive is set, and trailing '/' is
// trimmed unless StrictRouting is set. Escape characters are then removed.
func (app *App) normalizePath(path string) string {
	normalized := path
	switch {
	case normalized == "":
		normalized = "/"
	case normalized[0] != '/':
		normalized = "/" + normalized
	}

	if !app.config.CaseSensitive {
		normalized = utilsstrings.ToLower(normalized)
	}
	if !app.config.StrictRouting && len(normalized) > 1 {
		normalized = utils.TrimRight(normalized, '/')
	}

	return RemoveEscapeChar(normalized)
}
// RemoveRoute removes every route whose path matches path from the stack.
// path is normalized exactly like registered paths before being compared.
// If no methods are specified, it will remove the route for all methods defined in the app.
// You should call RebuildTree after using this to ensure consistency of the tree.
func (app *App) RemoveRoute(path string, methods ...string) {
	target := app.normalizePath(path)
	app.RemoveRouteFunc(func(r *Route) bool {
		return r.path == target // compare against the private normalized path
	}, methods...)
}
// RemoveRouteByName removes every route whose Name equals name from the stack.
// If no methods are specified, it will remove the route for all methods defined in the app.
// You should call RebuildTree after using this to ensure consistency of the tree.
func (app *App) RemoveRouteByName(name string, methods ...string) {
	app.RemoveRouteFunc(func(r *Route) bool {
		return r.Name == name
	}, methods...)
}
// RemoveRouteFunc removes every route for which matchFunc returns true.
// If no methods are specified, it will remove the route for all methods defined in the app.
// You should call RebuildTree after using this to ensure consistency of the tree.
//
// Note: route.Path passed to matchFunc is the originally registered path, not
// the normalized one. matchFunc is invoked while the app mutex is held.
func (app *App) RemoveRouteFunc(matchFunc func(r *Route) bool, methods ...string) {
	app.deleteRoute(methods, matchFunc)
}
// deleteRoute removes, under app.mutex, every route in the stacks of methods
// for which matchFunc returns true. An empty methods slice means all of
// app.config.RequestMethods. Invalid methods are skipped.
//
// Handler accounting mirrors register. A middleware (use) route is copied
// into every method stack but counted only once. Its handlers are therefore
// subtracted only when removing across all request methods, and only once per
// normalized path. Removing a regular GET route also prunes the matching
// auto-generated HEAD route.
func (app *App) deleteRoute(methods []string, matchFunc func(r *Route) bool) {
	if len(methods) == 0 {
		methods = app.config.RequestMethods
	}

	app.mutex.Lock()
	defer app.mutex.Unlock()

	// Loop-invariant: whether this call targets every request method.
	allMethods := slices.Equal(methods, app.config.RequestMethods)
	removedUseRoutes := make(map[string]struct{})

	for _, method := range methods {
		// Uppercase HTTP methods
		method = utilsstrings.ToUpper(method)
		// Get unique HTTP method identifier
		m := app.methodInt(method)
		if m == -1 {
			continue // Skip invalid HTTP methods
		}

		// Walk backwards so deleting index i never shifts an unvisited entry.
		for i := len(app.stack[m]) - 1; i >= 0; i-- {
			route := app.stack[m][i]
			if !matchFunc(route) {
				continue // Skip if route does not match
			}

			// slices.Delete zeroes the vacated tail slot (Go 1.22+), so the
			// backing array does not keep a stale *Route reachable.
			app.stack[m] = slices.Delete(app.stack[m], i, i+1)
			app.routesRefreshed = true

			// Decrement global handler count. In middleware routes, only decrement once.
			decrement := !route.use
			if route.use && allMethods {
				if _, seen := removedUseRoutes[route.path]; !seen {
					removedUseRoutes[route.path] = struct{}{}
					decrement = true
				}
			}
			if decrement {
				atomic.AddUint32(&app.handlersCount, ^uint32(len(route.Handlers)-1)) //nolint:gosec // G115 - handler count is always small
			}

			if method == MethodGet && !route.use && !route.mount {
				app.pruneAutoHeadRouteLocked(route.path)
			}
		}
	}
}
// pruneAutoHeadRouteLocked removes an automatically generated HEAD route so a
// later explicit registration can take its place without duplicating handler
// chains. The caller must already hold app.mutex.
//
// path is normalized before comparison. Only the most recently added
// auto-generated, non-mount, non-middleware HEAD route with that path is
// removed, and its handlers are subtracted from the global handler count.
func (app *App) pruneAutoHeadRouteLocked(path string) {
	headIndex := app.methodInt(MethodHead)
	if headIndex == -1 {
		return
	}

	norm := app.normalizePath(path)
	headStack := app.stack[headIndex]
	for i := len(headStack) - 1; i >= 0; i-- {
		headRoute := headStack[i]
		if headRoute.path != norm || headRoute.mount || headRoute.use || !headRoute.autoHead {
			continue
		}
		// slices.Delete zeroes the vacated tail slot (Go 1.22+), so the
		// removed route is not kept alive by the backing array.
		app.stack[headIndex] = slices.Delete(headStack, i, i+1)
		app.routesRefreshed = true
		atomic.AddUint32(&app.handlersCount, ^uint32(len(headRoute.Handlers)-1)) //nolint:gosec // G115 - handler count is always small
		return
	}
}
// register normalizes pathRaw once and adds one Route per entry in methods.
// The pseudo-method methodUse registers middleware, which is copied into the
// stack of every method in app.config.RequestMethods.
//
// It panics when a non-group route has no handlers, when any handler is nil,
// or when a method is neither methodUse nor a known request method. All
// methods are validated before anything is registered, so an invalid method
// no longer leaves the earlier methods of the same call half-registered.
func (app *App) register(methods []string, pathRaw string, group *Group, handlers ...Handler) {
	// A regular route requires at least one ctx handler
	if len(handlers) == 0 && group == nil {
		panic(fmt.Sprintf("missing handler/middleware in route: %s\n", pathRaw))
	}
	// No nil handlers allowed
	for _, h := range handlers {
		if h == nil {
			panic(fmt.Sprintf("nil handler in route: %s\n", pathRaw))
		}
	}

	// Precompute path normalization ONCE
	if pathRaw == "" {
		pathRaw = "/"
	}
	if pathRaw[0] != '/' {
		pathRaw = "/" + pathRaw
	}
	pathPretty := pathRaw
	if !app.config.CaseSensitive {
		pathPretty = utilsstrings.ToLower(pathPretty)
	}
	if !app.config.StrictRouting && len(pathPretty) > 1 {
		pathPretty = utils.TrimRight(pathPretty, '/')
	}
	pathClean := RemoveEscapeChar(pathPretty)

	// Param keys come from the raw path (case preserved); matching uses the pretty path.
	parsedRaw := parseRoute(pathRaw, app.customConstraints...)
	parsedPretty := parseRoute(pathPretty, app.customConstraints...)

	// These flags depend only on the path and group, not on the method.
	isMount := group != nil && group.app != app
	isStar := pathClean == "/*"
	isRoot := pathClean == "/"

	// Validate (and uppercase) every method before mutating any app state.
	upperMethods := make([]string, len(methods))
	for i, method := range methods {
		method = utilsstrings.ToUpper(method)
		if method != methodUse && app.methodInt(method) == -1 {
			panic(fmt.Sprintf("add: invalid http method %s\n", method))
		}
		upperMethods[i] = method
	}

	for _, method := range upperMethods {
		isUse := method == methodUse
		route := Route{
			use:         isUse,
			mount:       isMount,
			star:        isStar,
			root:        isRoot,
			path:        pathClean,
			routeParser: parsedPretty,
			Params:      parsedRaw.params,
			group:       group,
			Path:        pathRaw,
			Method:      method,
			Handlers:    handlers,
		}

		// Increment global handler count (once per registration, even for middleware).
		atomic.AddUint32(&app.handlersCount, uint32(len(handlers))) //nolint:gosec // G115 - handler count is always small

		if !isUse {
			app.addRoute(method, &route)
			continue
		}
		// Middleware route matches all HTTP methods: add a separate copy to
		// each method stack to avoid duplicates during compression.
		for _, m := range app.config.RequestMethods {
			r := route
			app.addRoute(m, &r)
		}
	}
}
// addRoute inserts route into the stack for method while holding app.mutex.
//
// An explicit (non-mount, non-middleware) HEAD route first prunes any
// auto-generated HEAD route for the same path. If the stack's last route has
// the same Path and use flag, and neither route is a mount, the new handlers
// are appended to that route instead of adding a new entry. For non-mount
// routes, latestRoute is updated and the OnRoute hooks run; a hook error panics.
//
// NOTE(review): in the merge case, latestRoute and the hooks receive the
// incoming route, not the merged route that stays in the stack.
func (app *App) addRoute(method string, route *Route) {
	app.mutex.Lock()
	defer app.mutex.Unlock()

	methodIdx := app.methodInt(method)

	if method == MethodHead && !route.mount && !route.use {
		app.pruneAutoHeadRouteLocked(route.path)
	}

	// Read the stack only after pruning, which may have modified it.
	stack := app.stack[methodIdx]
	var last *Route
	if n := len(stack); n > 0 {
		last = stack[n-1]
	}
	mergeable := last != nil &&
		!route.mount && !last.mount &&
		last.Path == route.Path && last.use == route.use

	if mergeable {
		// Prevent identical consecutive registrations: extend the existing route.
		last.Handlers = append(last.Handlers, route.Handlers...)
	} else {
		route.Method = method
		app.stack[methodIdx] = append(stack, route)
		app.routesRefreshed = true
	}

	// Mounted routes neither become latestRoute nor trigger hooks.
	if route.mount {
		return
	}
	app.latestRoute = route
	if err := app.hooks.executeOnRouteHooks(route); err != nil {
		panic(err)
	}
}
// ensureAutoHeadRoutes is the locking wrapper around
// ensureAutoHeadRoutesLocked: it holds app.mutex for the duration of the call.
func (app *App) ensureAutoHeadRoutes() {
	app.mutex.Lock()
	defer app.mutex.Unlock()

	app.ensureAutoHeadRoutesLocked()
}
// ensureAutoHeadRoutesLocked adds a HEAD route for every regular GET route
// whose normalized path has no regular HEAD route yet. The caller must hold
// app.mutex.
//
// It does nothing when DisableHeadAutoRegister is set or when either GET or
// HEAD is not among the configured request methods. Each generated route is a
// copy of the GET route (including its group) marked autoHead. It is counted
// in handlersCount, becomes latestRoute, and is passed to the OnRoute hooks.
// The HEAD stack is written back once, after all routes are generated.
func (app *App) ensureAutoHeadRoutesLocked() {
	if app.config.DisableHeadAutoRegister {
		return
	}

	headIdx := app.methodInt(MethodHead)
	getIdx := app.methodInt(MethodGet)
	if headIdx == -1 || getIdx == -1 {
		return
	}

	getRoutes := app.stack[getIdx]
	if len(getRoutes) == 0 {
		return
	}

	// Paths that already have a regular HEAD route.
	headRoutes := app.stack[headIdx]
	covered := make(map[string]struct{}, len(headRoutes))
	for _, r := range headRoutes {
		if !r.mount && !r.use {
			covered[r.path] = struct{}{}
		}
	}

	changed := false
	for _, getRoute := range getRoutes {
		if getRoute.mount || getRoute.use {
			continue
		}
		if _, ok := covered[getRoute.path]; ok {
			continue
		}

		clone := app.copyRoute(getRoute)
		clone.group = getRoute.group
		clone.Method = MethodHead
		clone.autoHead = true
		// Fasthttp automatically omits response bodies when transmitting
		// HEAD responses, so the copied GET handler stack can execute
		// unchanged while still producing an empty body on the wire.
		headRoutes = append(headRoutes, clone)
		covered[getRoute.path] = struct{}{}

		app.routesRefreshed = true
		changed = true
		atomic.AddUint32(&app.handlersCount, uint32(len(clone.Handlers))) //nolint:gosec // G115 - handler count is always small
		app.latestRoute = clone
		if err := app.hooks.executeOnRouteHooks(clone); err != nil {
			panic(err)
		}
	}

	if changed {
		app.stack[headIdx] = headRoutes
	}
}
// RebuildTree rebuilds the prefix tree from the registered routes while
// holding app.mutex, and returns the app for chaining.
//
// It is meant for registering routes dynamically after the app has started.
// Avoid it in production: rebuilding is expensive, and it is not safe while
// requests are being served. The tree is normally built only during startup,
// and guarding it at request time would add noticeable overhead (a measured
// regression from 82.79 ns/op to 94.48 ns/op), see:
// https://github.com/gofiber/fiber/issues/2769#issuecomment-2227385283
func (app *App) RebuildTree() *App {
	app.mutex.Lock()
	defer app.mutex.Unlock()

	return app.buildTree()
}
// buildTree builds the prefix tree (app.treeStack) from app.stack. It does
// nothing unless routesRefreshed is set, and clears the flag afterwards.
//
// For every request method, routes are bucketed by a key made of the first
// maxDetectionPaths bytes of their leading constant segment. Routes without
// such a prefix get key 0 and are appended to every bucket, so each bucket
// keeps full registration order. Backing arrays from the previous build are
// reused when large enough.
func (app *App) buildTree() *App {
	if !app.routesRefreshed {
		return app
	}

	for methodIdx := range app.config.RequestMethods {
		stack := app.stack[methodIdx]

		// Pass 1: compute each route's prefix key and per-key counts.
		keys := make([]int, len(stack))
		unkeyed := 0
		perKey := make(map[int]int, len(stack))
		for idx, route := range stack {
			segs := route.routeParser.segs
			if len(segs) > 0 && len(segs[0].Const) >= maxDetectionPaths {
				lead := segs[0].Const
				keys[idx] = int(lead[0])<<16 | int(lead[1])<<8 | int(lead[2])
			}
			if keys[idx] != 0 {
				perKey[keys[idx]]++
			} else {
				unkeyed++
			}
		}

		// Pass 2: size every bucket up front, recycling previous buckets.
		previous := app.treeStack[methodIdx]
		buckets := make(map[int][]*Route, len(perKey)+1)
		buckets[0] = reuseRouteBucket(previous, 0, unkeyed)
		for key, n := range perKey {
			buckets[key] = reuseRouteBucket(previous, key, n+unkeyed)
		}

		// Pass 3: distribute routes in registration order.
		for idx, route := range stack {
			if key := keys[idx]; key != 0 {
				buckets[key] = append(buckets[key], route)
				continue
			}
			for key, bucket := range buckets {
				buckets[key] = append(bucket, route)
			}
		}

		app.treeStack[methodIdx] = buckets
	}

	app.routesRefreshed = false
	return app
}
// reuseRouteBucket returns an empty slice for key, with capacity for at least
// capHint routes. It reuses prev[key]'s backing array when that bucket exists
// and is large enough, and otherwise allocates a new one. A reused array is
// overwritten in place.
func reuseRouteBucket(prev map[int][]*Route, key, capHint int) []*Route {
	bucket, found := prev[key]
	if !found || cap(bucket) < capHint {
		return make([]*Route, 0, capHint)
	}
	return bucket[:0]
}
================================================
FILE: router_test.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📃 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io
package fiber
import (
"bufio"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"reflect"
"runtime"
"strings"
"sync"
"testing"
"github.com/gofiber/utils/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// routesFixture holds the route table decoded from testRoutes.json. It is
// loaded once in init; any read or decode failure aborts the test binary.
var routesFixture routeJSON

func init() {
	raw, readErr := os.ReadFile("./.github/testdata/testRoutes.json")
	if readErr != nil {
		panic(readErr)
	}
	if decodeErr := json.Unmarshal(raw, &routesFixture); decodeErr != nil {
		panic(decodeErr)
	}
}
// Test_Route_Handler_Order verifies that handlers registered in one call run
// in registration order when each calls c.Next().
func Test_Route_Handler_Order(t *testing.T) {
	t.Parallel()
	app := New()

	var order []int
	handler1 := func(c Ctx) error {
		order = append(order, 1)
		return c.Next()
	}
	handler2 := func(c Ctx) error {
		order = append(order, 2)
		return c.Next()
	}
	handler3 := func(c Ctx) error {
		order = append(order, 3)
		return c.Next()
	}
	app.Get("/test", handler1, handler2, handler3, func(c Ctx) error {
		order = append(order, 4)
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	// Close the body like the other tests in this file do.
	t.Cleanup(func() {
		require.NoError(t, resp.Body.Close())
	})
	require.Equal(t, 200, resp.StatusCode, "Status code")

	expectedOrder := []int{1, 2, 3, 4}
	require.Equal(t, expectedOrder, order, "Handler order")
}
// Test_hasFlashCookieExactMatch checks that hasFlashCookie only reports the
// exact flash cookie name and only for cookies parsed from a request.
func Test_hasFlashCookieExactMatch(t *testing.T) {
	t.Parallel()

	// parseRaw builds a request by parsing raw HTTP/1.1 text with the given Cookie header.
	parseRaw := func(t *testing.T, cookie string) *fasthttp.Request {
		t.Helper()
		raw := "GET / HTTP/1.1\r\nHost: localhost\r\nCookie: " + cookie + "\r\n\r\n"
		parsed := new(fasthttp.Request)
		require.NoError(t, parsed.Read(bufio.NewReader(strings.NewReader(raw))))
		return parsed
	}

	// A cookie whose name merely starts with FlashCookieName must not count.
	prefixed := parseRaw(t, FlashCookieName+"X=not-the-flash-cookie")
	require.False(t, hasFlashCookie(&prefixed.Header))

	// The exact cookie name is detected.
	exact := parseRaw(t, FlashCookieName+"=valid")
	require.True(t, hasFlashCookie(&exact.Header))

	// A Cookie header set via Header.Set (not parsed) is not reported.
	var synthetic fasthttp.Request
	synthetic.Header.Set(HeaderCookie, FlashCookieName+"=valid")
	require.False(t, hasFlashCookie(&synthetic.Header))
}
// Test_Route_MixedFiberAndHTTPHandlers checks that an http.Handler placed in
// a Fiber handler chain responds and ends the chain: the Fiber handler after
// it never runs.
func Test_Route_MixedFiberAndHTTPHandlers(t *testing.T) {
	t.Parallel()
	app := New()

	var calls []string
	before := func(c Ctx) error {
		calls = append(calls, "fiber-before")
		c.Set("X-Fiber", "1")
		return c.Next()
	}
	std := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		calls = append(calls, "http-final")
		w.Header().Set("X-HTTP", "true")
		_, writeErr := w.Write([]byte("http"))
		assert.NoError(t, writeErr)
	})
	after := func(c Ctx) error {
		calls = append(calls, "fiber-after")
		return c.SendString("fiber")
	}
	app.Get("/mixed", before, std, after)

	res, err := app.Test(httptest.NewRequest(MethodGet, "/mixed", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, res.StatusCode)
	t.Cleanup(func() {
		require.NoError(t, res.Body.Close())
	})

	payload, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, "http", string(payload))
	require.Equal(t, "true", res.Header.Get("X-HTTP"))
	require.Equal(t, "1", res.Header.Get("X-Fiber"))
	require.Equal(t, []string{"fiber-before", "http-final"}, calls)
}
// Test_Route_Group_WithHTTPHandlers checks that app middleware, group
// middleware and an http.Handler route in the group run in that order.
func Test_Route_Group_WithHTTPHandlers(t *testing.T) {
	t.Parallel()
	app := New()

	var calls []string
	app.Use("/api", func(c Ctx) error {
		calls = append(calls, "app-use")
		return c.Next()
	})
	api := app.Group("/api", func(c Ctx) error {
		calls = append(calls, "group-middleware")
		return c.Next()
	})
	api.Get("/users", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		calls = append(calls, "http-handler")
		_, writeErr := w.Write([]byte("users"))
		assert.NoError(t, writeErr)
	}))

	res, err := app.Test(httptest.NewRequest(MethodGet, "/api/users", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, res.StatusCode)
	t.Cleanup(func() {
		require.NoError(t, res.Body.Close())
	})

	payload, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, "users", string(payload))
	require.Equal(t, []string{"app-use", "group-middleware", "http-handler"}, calls)
}
// Test_RouteChain_WithHTTPHandlers checks that RouteChain accepts mixed Fiber
// and http.Handler handlers, both on the chain and on a nested chain.
func Test_RouteChain_WithHTTPHandlers(t *testing.T) {
	t.Parallel()
	app := New()

	combo := app.RouteChain("/combo")
	combo.Get(func(c Ctx) error {
		c.Set("X-Chain", "fiber")
		return c.Next()
	}, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, writeErr := w.Write([]byte("combo"))
		assert.NoError(t, writeErr)
	}))
	combo.RouteChain("/nested").Get(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("X-Nested", "true")
		_, writeErr := w.Write([]byte("nested"))
		assert.NoError(t, writeErr)
	}))

	// Outer chain: Fiber middleware followed by an http.Handler.
	outer, err := app.Test(httptest.NewRequest(MethodGet, "/combo", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, outer.StatusCode)
	require.Equal(t, "fiber", outer.Header.Get("X-Chain"))
	t.Cleanup(func() {
		require.NoError(t, outer.Body.Close())
	})
	outerBody, err := io.ReadAll(outer.Body)
	require.NoError(t, err)
	require.Equal(t, "combo", string(outerBody))

	// Nested chain: a lone http.Handler.
	inner, err := app.Test(httptest.NewRequest(MethodGet, "/combo/nested", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, inner.StatusCode)
	require.Equal(t, "true", inner.Header.Get("X-Nested"))
	t.Cleanup(func() {
		require.NoError(t, inner.Body.Close())
	})
	innerBody, err := io.ReadAll(inner.Body)
	require.NoError(t, err)
	require.Equal(t, "nested", string(innerBody))
}
// Test_Route_Match_SameLength checks that a single-parameter route "/:param"
// matches both the literal pattern text and a real value of the same shape,
// echoing the captured parameter back.
func Test_Route_Match_SameLength(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/:param", func(c Ctx) error {
		return c.SendString(c.Params("param"))
	})

	for _, tc := range []struct{ path, want string }{
		{path: "/:param", want: ":param"}, // literal pattern text as the request path
		{path: "/test", want: "test"},     // regular parameter value
	} {
		resp, err := app.Test(httptest.NewRequest(MethodGet, tc.path, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, tc.want, app.toString(body))
	}
}
func Test_Route_Match_Star(t *testing.T) {
t.Parallel()
app := New()
app.Get("/*", func(c Ctx) error {
return c.SendString(c.Params("*"))
})
resp, err := app.Test(httptest.NewRequest(MethodGet, "/*", http.NoBody))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, 200, resp.StatusCode, "Status code")
body, err := io.ReadAll(resp.Body)
require.NoError(t, err, "app.Test(req)")
require.Equal(t, "*", app.toString(body))
// with param
resp, err = app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
require.NoError(t, err, "app.Test(req)")
require.Equal(t, 200, resp.StatusCode, "Status code")
body, err = io.ReadAll(resp.Body)
require.NoError(t, err, "app.Test(req)")
require.Equal(t, "test", app.toString(body))
// without parameter
route := Route{
star: true,
path: "/*",
routeParser: routeParser{},
}
params := [maxParams]string{}
match := route.match("", "", ¶ms)
require.True(t, match)
require.Equal(t, [maxParams]string{}, params)
// with parameter
match = route.match("/favicon.ico", "/favicon.ico", ¶ms)
require.True(t, match)
require.Equal(t, [maxParams]string{"favicon.ico"}, params)
// without parameter again
match = route.match("", "", ¶ms)
require.True(t, match)
require.Equal(t, [maxParams]string{}, params)
}
// Test_Route_Match_Root checks that a GET handler registered on "/" serves a
// request for the root path.
func Test_Route_Match_Root(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/", func(c Ctx) error {
		return c.SendString("root")
	})

	req := httptest.NewRequest(MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	payload, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "root", app.toString(payload))
}
// Test_Route_Match_Parser checks the route parser with a mixed-case named
// parameter and a mixed-case static prefix followed by a wildcard.
func Test_Route_Match_Parser(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/foo/:ParamName", func(c Ctx) error {
		return c.SendString(c.Params("ParamName"))
	})
	app.Get("/Foobar/*", func(c Ctx) error {
		return c.SendString(c.Params("*"))
	})

	for _, tc := range []struct{ path, want string }{
		{path: "/foo/bar", want: "bar"},      // named parameter
		{path: "/Foobar/test", want: "test"}, // wildcard
	} {
		resp, err := app.Test(httptest.NewRequest(MethodGet, tc.path, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, tc.want, app.toString(body))
	}
}
// TestAutoRegisterHeadRoutes covers automatic HEAD registration for GET
// handlers. It tests the default behavior, the DisableHeadAutoRegister opt-out,
// interaction with explicit HEAD handlers (registered before or after GET),
// grouped and file-serving routes, route listing, and HEAD responses that keep
// status and headers but carry no body.
func TestAutoRegisterHeadRoutes(t *testing.T) {
	t.Parallel()

	requireClose := func(tb testing.TB, closer io.Closer) {
		tb.Helper()
		require.NoError(tb, closer.Close())
	}
	// registerCleanup closes a response body when the (sub)test finishes.
	registerCleanup := func(tb testing.TB, body io.ReadCloser) {
		tb.Helper()
		tb.Cleanup(func() {
			requireClose(tb, body)
		})
	}

	// Each case carries its own name next to its body, so the two cannot drift
	// apart. Previously they lived in two parallel slices matched by index.
	cases := []struct {
		run  func(t *testing.T)
		name string
	}{
		{
			name: "auto registers head for get",
			run: func(t *testing.T) {
				t.Helper()
				app := New()
				app.Get("/", func(c Ctx) error {
					c.Set("X-Test", "auto")
					return c.SendString("Hello")
				})
				respHead, err := app.Test(httptest.NewRequest(MethodHead, "/", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respHead.Body)
				require.Equal(t, StatusOK, respHead.StatusCode)
				require.Equal(t, int64(len("Hello")), respHead.ContentLength)
				require.Equal(t, "auto", respHead.Header.Get("X-Test"))
				body, err := io.ReadAll(respHead.Body)
				require.NoError(t, err)
				require.Empty(t, body)
				respGet, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respGet.Body)
				require.Equal(t, StatusOK, respGet.StatusCode)
				require.Equal(t, int64(len("Hello")), respGet.ContentLength)
				data, err := io.ReadAll(respGet.Body)
				require.NoError(t, err)
				require.Equal(t, "Hello", string(data))
			},
		},
		{
			name: "disable auto register config",
			run: func(t *testing.T) {
				t.Helper()
				app := New(Config{DisableHeadAutoRegister: true})
				app.Get("/", func(c Ctx) error {
					return c.SendString("Hello")
				})
				resp, err := app.Test(httptest.NewRequest(MethodHead, "/", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, resp.Body)
				require.Equal(t, StatusMethodNotAllowed, resp.StatusCode)
			},
		},
		{
			name: "explicit head overrides auto head route",
			run: func(t *testing.T) {
				t.Helper()
				app := New()
				var getCalls int
				app.Get("/override", func(c Ctx) error {
					getCalls++
					return c.SendString("GET")
				})
				respHead, err := app.Test(httptest.NewRequest(MethodHead, "/override", http.NoBody))
				require.NoError(t, err)
				require.Equal(t, StatusOK, respHead.StatusCode)
				require.Equal(t, 1, getCalls)
				requireClose(t, respHead.Body)
				var headCalls int
				app.Head("/override", func(c Ctx) error {
					headCalls++
					c.Set("X-Explicit", "true")
					return c.SendStatus(StatusNoContent)
				})
				respHead, err = app.Test(httptest.NewRequest(MethodHead, "/override", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respHead.Body)
				require.Equal(t, StatusNoContent, respHead.StatusCode)
				require.Equal(t, "true", respHead.Header.Get("X-Explicit"))
				body, err := io.ReadAll(respHead.Body)
				require.NoError(t, err)
				require.Empty(t, body)
				require.Equal(t, 1, getCalls)
				require.Equal(t, 1, headCalls)
			},
		},
		{
			name: "auto head for grouped routes",
			run: func(t *testing.T) {
				t.Helper()
				app := New()
				group := app.Group("/api")
				group.Get("/users/:id", func(c Ctx) error {
					c.Set("X-User", c.Params("id"))
					return c.SendString("grouped")
				})
				respHead, err := app.Test(httptest.NewRequest(MethodHead, "/api/users/42", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respHead.Body)
				require.Equal(t, StatusOK, respHead.StatusCode)
				require.Equal(t, "42", respHead.Header.Get("X-User"))
				body, err := io.ReadAll(respHead.Body)
				require.NoError(t, err)
				require.Empty(t, body)
			},
		},
		{
			name: "static handler auto head",
			run: func(t *testing.T) {
				t.Helper()
				const file = "./.github/testdata/testRoutes.json"
				content, err := os.ReadFile(file)
				require.NoError(t, err)
				app := New()
				app.Get("/file", func(c Ctx) error {
					return c.SendFile(file)
				})
				respHead, err := app.Test(httptest.NewRequest(MethodHead, "/file", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respHead.Body)
				require.Equal(t, StatusOK, respHead.StatusCode)
				require.Equal(t, int64(len(content)), respHead.ContentLength)
				body, err := io.ReadAll(respHead.Body)
				require.NoError(t, err)
				require.Empty(t, body)
				respGet, err := app.Test(httptest.NewRequest(MethodGet, "/file", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respGet.Body)
				data, err := io.ReadAll(respGet.Body)
				require.NoError(t, err)
				require.Equal(t, content, data)
			},
		},
		{
			name: "head without matching route returns 404",
			run: func(t *testing.T) {
				t.Helper()
				app := New()
				resp, err := app.Test(httptest.NewRequest(MethodHead, "/missing", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, resp.Body)
				require.Equal(t, StatusNotFound, resp.StatusCode)
			},
		},
		{
			name: "late explicit get keeps explicit head",
			run: func(t *testing.T) {
				t.Helper()
				app := New()
				var headCalls int
				app.Head("/late", func(c Ctx) error {
					headCalls++
					c.Set("X-Late", "head")
					return c.SendStatus(StatusAccepted)
				})
				var getCalls int
				app.Get("/late", func(c Ctx) error {
					getCalls++
					return c.SendString("ok")
				})
				respHead, err := app.Test(httptest.NewRequest(MethodHead, "/late", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respHead.Body)
				require.Equal(t, StatusAccepted, respHead.StatusCode)
				require.Equal(t, "head", respHead.Header.Get("X-Late"))
				require.Equal(t, 1, headCalls)
				require.Equal(t, 0, getCalls)
				respGet, err := app.Test(httptest.NewRequest(MethodGet, "/late", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respGet.Body)
				require.Equal(t, StatusOK, respGet.StatusCode)
				require.Equal(t, 1, getCalls)
			},
		},
		{
			name: "route listing includes auto head",
			run: func(t *testing.T) {
				t.Helper()
				app := New()
				app.Get("/list", func(c Ctx) error {
					return c.SendString("list")
				})
				app.startupProcess()
				routes := app.GetRoutes()
				var hasGet, hasHead bool
				for _, route := range routes {
					if route.Path == "/list" {
						if route.Method == MethodGet {
							hasGet = true
						}
						if route.Method == MethodHead {
							hasHead = true
						}
					}
				}
				require.True(t, hasGet)
				require.True(t, hasHead)
			},
		},
		{
			name: "head mirrors status without body",
			run: func(t *testing.T) {
				t.Helper()
				app := New()
				app.Get("/nocontent", func(c Ctx) error {
					return c.SendStatus(StatusNoContent)
				})
				respHead, err := app.Test(httptest.NewRequest(MethodHead, "/nocontent", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respHead.Body)
				require.Equal(t, StatusNoContent, respHead.StatusCode)
				body, err := io.ReadAll(respHead.Body)
				require.NoError(t, err)
				require.Empty(t, body)
				respGet, err := app.Test(httptest.NewRequest(MethodGet, "/nocontent", http.NoBody))
				require.NoError(t, err)
				registerCleanup(t, respGet.Body)
				require.Equal(t, StatusNoContent, respGet.StatusCode)
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			tc.run(t)
		})
	}
}
// Test_Route_Match_Middleware checks that middleware mounted on "/foo/*"
// exposes the wildcard remainder via c.Params("*"). It tests both the literal
// "*" path and a multi-segment path.
func Test_Route_Match_Middleware(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use("/foo/*", func(c Ctx) error {
		return c.SendString(c.Params("*"))
	})

	for _, tc := range []struct{ path, want string }{
		{path: "/foo/*", want: "*"},
		{path: "/foo/bar/fasel", want: "bar/fasel"}, // with param
	} {
		resp, err := app.Test(httptest.NewRequest(MethodGet, tc.path, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, 200, resp.StatusCode, "Status code")
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, tc.want, app.toString(body))
	}
}
// Test_Route_Match_UnescapedPath checks that with UnescapePath enabled, a
// percent-encoded request path matches a route registered with non-ASCII
// characters, and that the encoded form stops matching once the option is off.
func Test_Route_Match_UnescapedPath(t *testing.T) {
	t.Parallel()
	app := New(Config{UnescapePath: true})
	app.Use("/créer", func(c Ctx) error {
		return c.SendString("test")
	})

	get := func(target string) *http.Response {
		t.Helper()
		resp, err := app.Test(httptest.NewRequest(MethodGet, target, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		return resp
	}

	// Percent-encoded form.
	encoded := get("/cr%C3%A9er")
	require.Equal(t, StatusOK, encoded.StatusCode, "Status code")
	body, err := io.ReadAll(encoded.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "test", app.toString(body))

	// without special chars
	require.Equal(t, StatusOK, get("/créer").StatusCode, "Status code")

	// check deactivated behavior
	app.config.UnescapePath = false
	require.Equal(t, StatusNotFound, get("/cr%C3%A9er").StatusCode, "Status code")
}
// Test_Route_Match_WithEscapeChar checks that "\\:" in a route pattern is
// treated as a literal colon rather than a parameter marker. It tests static
// routes, group prefixes, and routes that also contain a real parameter.
func Test_Route_Match_WithEscapeChar(t *testing.T) {
	t.Parallel()
	app := New()

	// static route and escaped part
	app.Get("/v1/some/resource/name\\:customVerb", func(c Ctx) error {
		return c.SendString("static")
	})
	// group route
	app.Group("/v2/\\:firstVerb").Get("/\\:customVerb", func(c Ctx) error {
		return c.SendString("group")
	})
	// route with resource param and escaped part
	app.Get("/v3/:resource/name\\:customVerb", func(c Ctx) error {
		return c.SendString(c.Params("resource"))
	})

	for _, tc := range []struct{ path, want string }{
		{path: "/v1/some/resource/name:customVerb", want: "static"},
		{path: "/v2/:firstVerb/:customVerb", want: "group"},
		{path: "/v3/awesome/name:customVerb", want: "awesome"},
	} {
		resp, err := app.Test(httptest.NewRequest(MethodGet, tc.path, http.NoBody))
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, StatusOK, resp.StatusCode, "Status code")
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, tc.want, app.toString(body))
	}
}
// Test_Route_Match_Middleware_HasPrefix checks that middleware mounted on
// "/foo" also handles deeper paths such as "/foo/bar".
func Test_Route_Match_Middleware_HasPrefix(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use("/foo", func(c Ctx) error {
		return c.SendString("middleware")
	})

	req := httptest.NewRequest(MethodGet, "/foo/bar", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	payload, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "middleware", app.toString(payload))
}
// Test_Route_Match_Middleware_NoBoundary checks that middleware mounted on
// "/foo" does not match "/foobar", because the prefix must end at a segment
// boundary.
func Test_Route_Match_Middleware_NoBoundary(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use("/foo", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})

	req := httptest.NewRequest(MethodGet, "/foobar", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusNotFound, resp.StatusCode, "Status code")
}
// Test_Route_Match_Middleware_Root checks that middleware mounted on "/"
// handles arbitrary paths.
func Test_Route_Match_Middleware_Root(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use("/", func(c Ctx) error {
		return c.SendString("middleware")
	})

	req := httptest.NewRequest(MethodGet, "/everything", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	payload, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "middleware", app.toString(payload))
}
// Test_Router_Register_Missing_Handler checks the panic messages produced by
// app.register when no handler is supplied and when a nil handler is supplied.
func Test_Router_Register_Missing_Handler(t *testing.T) {
	t.Parallel()
	app := New()
	const route = "/doe"

	t.Run("No Handler", func(t *testing.T) {
		t.Parallel()
		want := "missing handler/middleware in route: " + route + "\n"
		require.PanicsWithValue(t, want, func() {
			app.register([]string{"USE"}, route, nil)
		})
	})

	t.Run("Nil Handler", func(t *testing.T) {
		t.Parallel()
		want := "nil handler in route: " + route + "\n"
		require.PanicsWithValue(t, want, func() {
			app.register([]string{"USE"}, route, nil, nil)
		})
	})
}
// Test_Ensure_Router_Interface_Implementation asserts that both *App and
// *Group satisfy the Router interface.
func Test_Ensure_Router_Interface_Implementation(t *testing.T) {
	t.Parallel()
	for _, candidate := range []any{(*App)(nil), (*Group)(nil)} {
		_, implements := candidate.(Router)
		require.True(t, implements)
	}
}
// Test_Router_Handler_Catch_Error checks that when the configured ErrorHandler
// itself returns an error, the response ends up as 500 Internal Server Error.
func Test_Router_Handler_Catch_Error(t *testing.T) {
	t.Parallel()
	app := New()
	app.config.ErrorHandler = func(_ Ctx, _ error) error {
		return errors.New("fake error")
	}
	app.Get("/", func(_ Ctx) error {
		return ErrForbidden
	})

	fctx := &fasthttp.RequestCtx{}
	serve := app.Handler()
	serve(fctx)

	require.Equal(t, StatusInternalServerError, fctx.Response.Header.StatusCode())
}
// Test_Router_NotFound checks that a request falling through a pass-through
// middleware to an unregistered path yields 404 with a "Not Found" body.
func Test_Router_NotFound(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		return c.Next()
	})
	serve := app.Handler()

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod("DELETE")
	fctx.URI().SetPath("/this/route/does/not/exist")
	serve(fctx)

	require.Equal(t, 404, fctx.Response.StatusCode())
	require.Equal(t, "Not Found", string(fctx.Response.Body()))
}
// Test_Router_NotFound_HTML_Inject ensures that markup in an unmatched request
// path is not reflected into the 404 response body. The body must stay the
// fixed "Not Found" message.
func Test_Router_NotFound_HTML_Inject(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		return c.Next()
	})
	appHandler := app.Handler()
	c := &fasthttp.RequestCtx{}
	c.Request.Header.SetMethod("DELETE")
	// The path must carry a script payload for this test to exercise
	// injection at all. With a plain path it would just duplicate
	// Test_Router_NotFound.
	c.URI().SetPath("/does/not/exist<script>alert('foo');</script>")
	appHandler(c)
	require.Equal(t, 404, c.Response.StatusCode())
	require.Equal(t, "Not Found", string(c.Response.Body()))
}
// registerTreeManipulationRoutes registers GET /test with any extra middleware
// appended as further handlers. When served, GET /test registers GET
// /dynamically-defined and rebuilds the route tree before replying 200 OK.
func registerTreeManipulationRoutes(app *App, middleware ...func(Ctx) error) {
	extra := make([]any, 0, len(middleware))
	for _, mw := range middleware {
		extra = append(extra, mw)
	}

	defineRouteAtRuntime := func(c Ctx) error {
		app.Get("/dynamically-defined", func(inner Ctx) error {
			return inner.SendStatus(StatusOK)
		})
		app.RebuildTree()
		return c.SendStatus(StatusOK)
	}
	app.Get("/test", defineRouteAtRuntime, extra...)
}
// verifyRequest sends a GET request for path through app.Test, asserts that the
// response status equals expectedStatus, and returns the response so callers
// can inspect the body.
func verifyRequest(tb testing.TB, app *App, path string, expectedStatus int) *http.Response {
	tb.Helper()
	req := httptest.NewRequest(MethodGet, path, http.NoBody)
	resp, err := app.Test(req)
	require.NoError(tb, err, "app.Test(req)")
	require.Equal(tb, expectedStatus, resp.StatusCode, "Status code")
	return resp
}
// verifyRouteHandlerCounts asserts that every route in app.stack has exactly
// expectedRoutesCount handlers. Handler names are rendered the same way as
// listen.go's printRoutesMessage: each function name followed by one space.
// The spaces in that rendering are then counted.
func verifyRouteHandlerCounts(tb testing.TB, app *App, expectedRoutesCount int) {
	tb.Helper()
	for _, methodRoutes := range app.stack {
		for _, route := range methodRoutes {
			rendered := ""
			for _, handler := range route.Handlers {
				rendered += runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name() + " "
			}
			require.Equal(tb, expectedRoutesCount, strings.Count(rendered, " "))
		}
	}
}
// verifyThereAreNoRoutes asserts that the app has no handlers left. It checks
// the global handler counter and also inspects every route in the stack.
func verifyThereAreNoRoutes(tb testing.TB, app *App) {
	tb.Helper()
	const none = 0
	require.Equal(tb, uint32(none), app.handlersCount)
	verifyRouteHandlerCounts(tb, app, none)
}
// Test_App_Rebuild_Tree checks that a route registered from inside a handler,
// followed by RebuildTree, becomes reachable on subsequent requests.
func Test_App_Rebuild_Tree(t *testing.T) {
	t.Parallel()
	app := New()
	registerTreeManipulationRoutes(app)

	steps := []struct {
		path   string
		status int
	}{
		{path: "/dynamically-defined", status: StatusNotFound}, // not yet defined
		{path: "/test", status: StatusOK},                      // defines it and rebuilds
		{path: "/dynamically-defined", status: StatusOK},       // now reachable
	}
	for _, step := range steps {
		verifyRequest(t, app, step.path, step.status)
	}
}
// Test_App_Remove_Route_A_B_Feature_Testing simulates A/B switching. Hitting
// /api/feature-a or /api/feature-b replaces the /api/feature route at runtime,
// and the next /api/feature request must reflect the chosen variant.
func Test_App_Remove_Route_A_B_Feature_Testing(t *testing.T) {
	t.Parallel()
	app := New()

	// swapFeature builds a handler that removes /api/feature, redefines it to
	// answer with message, and rebuilds the tree after each step.
	swapFeature := func(message string) func(Ctx) error {
		return func(c Ctx) error {
			app.RemoveRoute("/api/feature", MethodGet)
			app.RebuildTree()
			app.Get("/api/feature", func(inner Ctx) error {
				return inner.SendString(message)
			})
			app.RebuildTree()
			return c.SendStatus(StatusOK)
		}
	}
	app.Get("/api/feature-a", swapFeature("Testing feature-a"))
	app.Get("/api/feature-b", swapFeature("Testing feature-b"))

	for _, variant := range []string{"a", "b"} {
		verifyRequest(t, app, "/api/feature-"+variant, StatusOK)
		resp := verifyRequest(t, app, "/api/feature", StatusOK)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err, "app.Test(req)")
		require.Equal(t, "Testing feature-"+variant, string(body), "Response Message")
	}
}
// Test_App_Remove_Route_By_Name removes a named GET route and verifies that
// its path no longer resolves and that no handlers remain registered.
func Test_App_Remove_Route_By_Name(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/api/test", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	}).Name("test")
	app.RemoveRouteByName("test", MethodGet)
	app.RebuildTree()
	// Probe the path that was actually registered. "/test" was never a route,
	// so a 404 there would pass even if the removal did nothing.
	verifyRequest(t, app, "/api/test", StatusNotFound)
	verifyThereAreNoRoutes(t, app)
}
// Test_App_Remove_Route_By_Name_Non_Existing_Route checks that removing a
// route name that was never registered leaves the app without routes and does
// not fail.
func Test_App_Remove_Route_By_Name_Non_Existing_Route(t *testing.T) {
	t.Parallel()
	const unknownName = "test"
	app := New()
	app.RemoveRouteByName(unknownName, MethodGet)
	app.RebuildTree()
	verifyThereAreNoRoutes(t, app)
}
// Test_App_Remove_Route_Nested checks that a route registered through nested
// groups can be removed by its full path.
func Test_App_Remove_Route_Nested(t *testing.T) {
	t.Parallel()
	const fullPath = "/api/v1/test"
	app := New()
	v1 := app.Group("/api").Group("/v1")
	v1.Get("/test", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})

	verifyRequest(t, app, fullPath, StatusOK)
	app.RemoveRoute(fullPath, MethodGet)
	verifyThereAreNoRoutes(t, app)
}
// Test_App_Remove_Route_Parameterized checks that a parameterized route can be
// removed by passing its pattern to RemoveRoute.
func Test_App_Remove_Route_Parameterized(t *testing.T) {
	t.Parallel()
	const pattern = "/test/:id"
	app := New()
	app.Get(pattern, func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})

	verifyRequest(t, app, pattern, StatusOK)
	app.RemoveRoute(pattern, MethodGet)
	verifyThereAreNoRoutes(t, app)
}
// Test_App_Remove_Route checks that a removed GET route answers 404 once the
// tree has been rebuilt.
func Test_App_Remove_Route(t *testing.T) {
	t.Parallel()
	const path = "/test"
	app := New()
	app.Get(path, func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})

	app.RemoveRoute(path, MethodGet)
	app.RebuildTree()
	verifyRequest(t, app, path, StatusNotFound)
}
// Test_App_Remove_Route_Non_Existing_Route checks that removing routes that
// were never registered leaves the app without routes and does not fail.
func Test_App_Remove_Route_Non_Existing_Route(t *testing.T) {
	t.Parallel()
	app := New()
	methods := []string{MethodGet, MethodHead}
	app.RemoveRoute("/test", methods...)
	app.RebuildTree()
	verifyThereAreNoRoutes(t, app)
}
// Test_App_Use_StrictRoutingBoundary verifies that app.Use("/api", ...) matches
// "/api" itself and anything under "/api/", but never a sibling prefix such as
// "/apiv1". This holds whether StrictRouting is on or off.
func Test_App_Use_StrictRoutingBoundary(t *testing.T) {
	t.Parallel()
	type testCase struct {
		name           string
		path           string
		expectedStatus int
		strictRouting  bool
		expectMatched  bool
	}
	testCases := []testCase{
		{
			name:           "Strict exact match",
			strictRouting:  true,
			path:           "/api",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict trailing slash partial",
			strictRouting:  true,
			path:           "/api/",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict nested partial",
			strictRouting:  true,
			path:           "/api/users",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict disallows sibling prefix",
			strictRouting:  true,
			path:           "/apiv1",
			expectMatched:  false,
			expectedStatus: StatusNotFound,
		},
		{
			name:           "Non-strict exact match",
			strictRouting:  false,
			path:           "/api",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict trailing slash partial",
			strictRouting:  false,
			path:           "/api/",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict nested partial",
			strictRouting:  false,
			path:           "/api/users",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict disallows sibling prefix",
			strictRouting:  false,
			path:           "/apiv1",
			expectMatched:  false,
			expectedStatus: StatusNotFound,
		},
	}
	for _, tt := range testCases {
		t.Run(tt.name, func(t *testing.T) {
			// Each subtest builds its own app, so nothing is shared and the
			// cases can run in parallel like the rest of this file.
			t.Parallel()
			app := New(Config{StrictRouting: tt.strictRouting})
			matched := false
			app.Use("/api", func(c Ctx) error {
				matched = true
				return c.SendStatus(StatusOK)
			})
			resp, err := app.Test(httptest.NewRequest(MethodGet, tt.path, http.NoBody))
			require.NoError(t, err)
			t.Cleanup(func() {
				require.NoError(t, resp.Body.Close())
			})
			require.Equal(t, tt.expectedStatus, resp.StatusCode)
			require.Equal(t, tt.expectMatched, matched)
		})
	}
}
// Test_Group_Use_StrictRoutingBoundary verifies that group middleware mounted
// on "/api" + "/v1" runs for "/api/v1" and anything under "/api/v1/", but not
// for a sibling prefix such as "/api/v1beta". This holds whether StrictRouting
// is on or off.
func Test_Group_Use_StrictRoutingBoundary(t *testing.T) {
	t.Parallel()
	type testCase struct {
		name           string
		path           string
		expectedStatus int
		strictRouting  bool
		expectMatched  bool
	}
	testCases := []testCase{
		{
			name:           "Strict group exact match",
			strictRouting:  true,
			path:           "/api/v1",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict group trailing slash partial",
			strictRouting:  true,
			path:           "/api/v1/",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict group nested partial",
			strictRouting:  true,
			path:           "/api/v1/users",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict group disallows sibling prefix",
			strictRouting:  true,
			path:           "/api/v1beta",
			expectMatched:  false,
			expectedStatus: StatusNotFound,
		},
		{
			name:           "Non-strict group exact match",
			strictRouting:  false,
			path:           "/api/v1",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict group trailing slash partial",
			strictRouting:  false,
			path:           "/api/v1/",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict group nested partial",
			strictRouting:  false,
			path:           "/api/v1/users",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict group disallows sibling prefix",
			strictRouting:  false,
			path:           "/api/v1beta",
			expectMatched:  false,
			expectedStatus: StatusNotFound,
		},
	}
	for _, tt := range testCases {
		t.Run(tt.name, func(t *testing.T) {
			// Each subtest builds its own app, so nothing is shared and the
			// cases can run in parallel like the rest of this file.
			t.Parallel()
			app := New(Config{StrictRouting: tt.strictRouting})
			grp := app.Group("/api")
			matched := false
			grp.Use("/v1", func(c Ctx) error {
				matched = true
				return c.Next()
			})
			grp.Get("/v1", func(c Ctx) error {
				return c.SendStatus(StatusOK)
			})
			grp.Get("/v1/*", func(c Ctx) error {
				return c.SendStatus(StatusOK)
			})
			resp, err := app.Test(httptest.NewRequest(MethodGet, tt.path, http.NoBody))
			require.NoError(t, err)
			t.Cleanup(func() {
				require.NoError(t, resp.Body.Close())
			})
			require.Equal(t, tt.expectedStatus, resp.StatusCode)
			require.Equal(t, tt.expectMatched, matched)
		})
	}
}
// Test_App_Remove_Route_Concurrent removes and re-adds the same route from
// several goroutines at once. It then checks that the route is still served
// after a final tree rebuild.
func Test_App_Remove_Route_Concurrent(t *testing.T) {
	t.Parallel()
	app := New()
	ok := func(c Ctx) error {
		return c.SendStatus(StatusOK)
	}
	app.Get("/test", ok)

	var workers sync.WaitGroup
	for range 10 {
		workers.Go(func() {
			app.RemoveRoute("/test", MethodGet)
			app.Get("/test", ok)
		})
	}
	workers.Wait()

	app.RebuildTree()
	verifyRequest(t, app, "/test", StatusOK)
}
// Test_Route_Registration_Prevent_Duplicate_With_Middleware registers the tree
// manipulation routes twice, once with extra middleware and once without. It
// then tracks app.handlersCount across a fixed request sequence, checking that
// repeated runtime registration of /dynamically-defined stops adding handlers.
func Test_Route_Registration_Prevent_Duplicate_With_Middleware(t *testing.T) {
	t.Parallel()
	app := New()
	passThrough := func(c Ctx) error {
		return c.Next()
	}
	registerTreeManipulationRoutes(app, passThrough)
	registerTreeManipulationRoutes(app)

	steps := []struct {
		path     string
		status   int
		handlers uint32
	}{
		{path: "/dynamically-defined", status: StatusNotFound, handlers: 6},
		{path: "/test", status: StatusOK, handlers: 7},
		{path: "/dynamically-defined", status: StatusOK, handlers: 8},
		{path: "/test", status: StatusOK, handlers: 9},
		{path: "/dynamically-defined", status: StatusOK, handlers: 9},
	}
	for _, step := range steps {
		verifyRequest(t, app, step.path, step.status)
		require.Equal(t, step.handlers, app.handlersCount)
	}
}
// TestNormalizePath checks App.normalizePath across the CaseSensitive and
// StrictRouting settings. It covers empty input, a missing leading slash,
// trailing slashes, and escaped characters.
func TestNormalizePath(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name          string
		path          string
		expected      string
		caseSensitive bool
		strictRouting bool
	}{
		{
			name:          "Empty path",
			path:          "",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/",
		},
		{
			name:          "No leading slash",
			path:          "users",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/users",
		},
		{
			name:          "With trailing slash and strict routing",
			path:          "/users/",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/users/",
		},
		{
			name:          "With trailing slash and non-strict routing",
			path:          "/users/",
			caseSensitive: true,
			strictRouting: false,
			expected:      "/users",
		},
		{
			name:          "Case sensitive",
			path:          "/Users",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/Users",
		},
		{
			name:          "Case insensitive",
			path:          "/Users",
			caseSensitive: false,
			strictRouting: true,
			expected:      "/users",
		},
		{
			name:          "With escape characters",
			path:          "/users\\/profile",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/users/profile",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Each case builds its own App value, so the cases are independent
			// and can run in parallel like the rest of this file.
			t.Parallel()
			app := &App{
				config: Config{
					CaseSensitive: tt.caseSensitive,
					StrictRouting: tt.strictRouting,
				},
			}
			result := app.normalizePath(tt.path)
			require.Equal(t, tt.expected, result)
		})
	}
}
// TestRemoveRoute exercises RemoveRoute and RemoveRouteFunc against global
// middleware, path-scoped middleware, and method routes. Each handler appends a
// digit to buf, so the test can check which handlers still run after each
// removal and how app.handlersCount changes.
func TestRemoveRoute(t *testing.T) {
	t.Parallel()
	app := New()
	var buf strings.Builder

	app.Use(func(c Ctx) error {
		buf.WriteString("1") //nolint:errcheck // not needed
		return c.Next()
	})
	app.Post("/", func(c Ctx) error {
		buf.WriteString("2") //nolint:errcheck // not needed
		return c.SendStatus(StatusOK)
	})
	app.Use("/test", func(c Ctx) error {
		buf.WriteString("3") //nolint:errcheck // not needed
		return c.Next()
	})
	app.Get("/test", func(c Ctx) error {
		buf.WriteString("4") //nolint:errcheck // not needed
		return c.SendStatus(StatusOK)
	})
	app.Post("/test", func(c Ctx) error {
		buf.WriteString("5") //nolint:errcheck // not needed
		return c.SendStatus(StatusOK)
	})

	// send issues a request through app.Test, closes the response body, and
	// returns the status code.
	send := func(method, path string) int {
		t.Helper()
		req, err := http.NewRequestWithContext(context.Background(), method, path, http.NoBody)
		require.NoError(t, err)
		resp, err := app.Test(req)
		require.NoError(t, err)
		require.NoError(t, resp.Body.Close())
		return resp.StatusCode
	}

	app.startupProcess()
	require.Equal(t, uint32(6), app.handlersCount)

	require.Equal(t, StatusOK, send(MethodPost, "/"))
	require.Equal(t, "12", buf.String())
	buf.Reset()

	require.Equal(t, StatusOK, send(MethodGet, "/test"))
	require.Equal(t, "134", buf.String())
	buf.Reset()

	require.Equal(t, uint32(6), app.handlersCount)
	app.RemoveRoute("/test", MethodGet)
	app.RebuildTree()
	// Removing an unknown method and a predicate that never matches must both
	// be no-ops.
	app.RemoveRoute("/test", "TEST")
	app.RebuildTree()
	app.RemoveRouteFunc(func(_ *Route) bool {
		return false
	}, MethodGet)

	status := send(MethodGet, "/test")
	buf.Reset()
	require.Equal(t, StatusMethodNotAllowed, status)
	require.Equal(t, uint32(4), app.handlersCount)

	app.RemoveRoute("/test", MethodPost)
	app.RebuildTree()
	require.Equal(t, uint32(3), app.handlersCount)

	require.Equal(t, StatusNotFound, send(MethodPost, "/test"))
	require.Equal(t, "1", buf.String())
	buf.Reset()

	require.Equal(t, StatusNotFound, send(MethodGet, "/test"))
	require.Equal(t, "1", buf.String())
	buf.Reset()

	app.RemoveRoute("/", MethodGet, MethodPost)
	require.Equal(t, uint32(2), app.handlersCount)

	require.Equal(t, StatusNotFound, send(MethodGet, "/"))
	require.Empty(t, buf.String())
	buf.Reset()

	app.RemoveRoute("/test", MethodGet, MethodPost)
	require.Equal(t, uint32(2), app.handlersCount)
	app.RemoveRoute("/test", app.config.RequestMethods...)
	require.Equal(t, uint32(1), app.handlersCount)

	app.Patch("/test", func(c Ctx) error {
		buf.WriteString("6") //nolint:errcheck // not needed
		return c.SendStatus(StatusOK)
	})
	require.Equal(t, uint32(2), app.handlersCount)

	// RemoveRoute without explicit methods is expected to clear every method
	// on the path.
	app.RemoveRoute("/test")
	app.RemoveRoute("/")
	app.RebuildTree()
	require.Equal(t, uint32(0), app.handlersCount)
}
//////////////////////////////////////////////
///////////////// BENCHMARKS /////////////////
//////////////////////////////////////////////
// registerDummyRoutes registers every route of the GitHub API fixture on app.
// All of them use a single handler that does nothing and returns nil.
func registerDummyRoutes(app *App) {
	noop := func(_ Ctx) error {
		return nil
	}
	for _, fixture := range routesFixture.GitHubAPI {
		methods := []string{fixture.Method}
		app.Add(methods, fixture.Path, noop)
	}
}
// acquireDefaultCtxForRouterBenchmark acquires a context for fctx from app. It
// returns the context as *DefaultCtx and aborts the benchmark with b.Fatal if
// the acquired context has a different concrete type.
func acquireDefaultCtxForRouterBenchmark(b *testing.B, app *App, fctx *fasthttp.RequestCtx) *DefaultCtx {
	b.Helper()
	defaultCtx, isDefault := app.AcquireCtx(fctx).(*DefaultCtx)
	if isDefault {
		return defaultCtx
	}
	b.Fatal("AcquireCtx did not return *DefaultCtx")
	return defaultCtx
}
// Benchmark_App_RebuildTree measures RebuildTree on an app holding the GitHub
// API fixture routes.
//
// go test -v -run=^$ -bench=Benchmark_App_RebuildTree -benchmem -count=4
func Benchmark_App_RebuildTree(b *testing.B) {
	app := New()
	registerDummyRoutes(app)
	b.ReportAllocs()
	// No b.ResetTimer is needed: b.Loop resets the timer on its first call, so
	// the setup above is already excluded from the measurement.
	for b.Loop() {
		// Flag routes as refreshed before every rebuild. This presumably keeps
		// RebuildTree from short-circuiting on later iterations; confirm against
		// RebuildTree.
		app.routesRefreshed = true
		app.RebuildTree()
	}
}
// Benchmark_App_MethodNotAllowed measures a DELETE request to a path that only
// has a GET route. After the loop it checks the 405 status, the Allow header
// (GET and HEAD), and the standard status message body.
//
// go test -v -run=^$ -bench=Benchmark_App_MethodNotAllowed -benchmem -count=4
func Benchmark_App_MethodNotAllowed(b *testing.B) {
	app := New()
	hello := func(c Ctx) error {
		return c.SendString("Hello World!")
	}
	app.All("/this/is/a/", hello)
	app.Get("/this/is/a/dummy/route/oke", hello)
	serve := app.Handler()

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod("DELETE")
	fctx.URI().SetPath("/this/is/a/dummy/route/oke")
	for b.Loop() {
		serve(fctx)
	}

	require.Equal(b, 405, fctx.Response.StatusCode())
	require.Equal(b, MethodGet+", "+MethodHead, string(fctx.Response.Header.Peek("Allow")))
	require.Equal(b, utils.StatusMessage(StatusMethodNotAllowed), string(fctx.Response.Body()))
}
// Benchmark_Router_NotFound measures an unmatched DELETE request that passes
// through a global pass-through middleware, with the GitHub API fixture routes
// registered. After the loop it checks the 404 "Not Found" response.
//
// go test -v ./... -run=^$ -bench=Benchmark_Router_NotFound -benchmem -count=4
func Benchmark_Router_NotFound(b *testing.B) {
	b.ReportAllocs()
	app := New()
	app.Use(func(c Ctx) error {
		return c.Next()
	})
	registerDummyRoutes(app)
	serve := app.Handler()

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod("DELETE")
	fctx.URI().SetPath("/this/route/does/not/exist")
	for b.Loop() {
		serve(fctx)
	}

	require.Equal(b, 404, fctx.Response.StatusCode())
	require.Equal(b, "Not Found", string(fctx.Response.Body()))
}
// Benchmark_Router_Handler measures serving DELETE /user/keys/1337 against the
// GitHub API fixture routes with the default config.
//
// go test -v ./... -run=^$ -bench=Benchmark_Router_Handler -benchmem -count=4
func Benchmark_Router_Handler(b *testing.B) {
	app := New()
	registerDummyRoutes(app)
	serve := app.Handler()

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod("DELETE")
	fctx.URI().SetPath("/user/keys/1337")
	for b.Loop() {
		serve(fctx)
	}
}
func Benchmark_Router_Handler_Strict_Case(b *testing.B) {
app := New(Config{
StrictRouting: true,
CaseSensitive: true,
})
registerDummyRoutes(app)
appHandler := app.Handler()
c := &fasthttp.RequestCtx{}
c.Request.Header.SetMethod("DELETE")
c.URI().SetPath("/user/keys/1337")
for b.Loop() {
appHandler(c)
}
}
// Benchmark_Router_Chain measures a GET "/" route whose single registration
// carries six handlers, each of which simply calls c.Next().
//
// go test -v ./... -run=^$ -bench=Benchmark_Router_Chain -benchmem -count=4
func Benchmark_Router_Chain(b *testing.B) {
	app := New()
	next := func(c Ctx) error {
		return c.Next()
	}
	rest := make([]any, 5)
	for i := range rest {
		rest[i] = next
	}
	app.Get("/", next, rest...)
	serve := app.Handler()

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(MethodGet)
	fctx.URI().SetPath("/")
	for b.Loop() {
		serve(fctx)
	}
}
// go test -v ./... -run=^$ -bench=Benchmark_Router_WithCompression -benchmem -count=4
func Benchmark_Router_WithCompression(b *testing.B) {
app := New()
handler := func(c Ctx) error {
return c.Next()
}
app.Get("/", handler)
app.Get("/", handler)
app.Get("/", handler)
app.Get("/", handler)
app.Get("/", handler)
app.Get("/", handler)
appHandler := app.Handler()
c := &fasthttp.RequestCtx{}
c.Request.Header.SetMethod(MethodGet)
c.URI().SetPath("/")
for b.Loop() {
appHandler(c)
}
}
// Benchmark_Startup_Process measures creating an App, registering the dummy
// routes and running startupProcess, all inside the timed loop.
// go test -run=^$ -bench=Benchmark_Startup_Process -benchmem -count=9
func Benchmark_Startup_Process(b *testing.B) {
	for b.Loop() {
		fresh := New()
		registerDummyRoutes(fresh)
		fresh.startupProcess()
	}
}
// Benchmark_Router_Next measures app.next directly on a reused DefaultCtx,
// resetting the route index before every call.
// go test -v ./... -run=^$ -bench=Benchmark_Router_Next -benchmem -count=4
func Benchmark_Router_Next(b *testing.B) {
	app := New()
	registerDummyRoutes(app)
	app.startupProcess()

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod("DELETE")
	fctx.URI().SetPath("/user/keys/1337")

	ctx := app.AcquireCtx(fctx).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
	var (
		matched bool
		nextErr error
	)
	for b.Loop() {
		ctx.indexRoute = -1
		matched, nextErr = app.next(ctx)
	}

	require.NoError(b, nextErr)
	require.True(b, matched)
	require.Equal(b, 4, ctx.indexRoute)
}
// go test -v ./... -run=^$ -bench=Benchmark_Router_Next_Default -benchmem -count=4
func Benchmark_Router_Next_Default(b *testing.B) {
app := New()
app.Get("/", func(_ Ctx) error {
return nil
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(MethodGet)
fctx.Request.SetRequestURI("/")
b.ReportAllocs()
for b.Loop() {
h(fctx)
}
}
// go test -benchmem -run=^$ -bench ^Benchmark_Router_Next_Default_Parallel$ github.com/gofiber/fiber/v3 -count=1
func Benchmark_Router_Next_Default_Parallel(b *testing.B) {
app := New()
app.Get("/", func(_ Ctx) error {
return nil
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(MethodGet)
fctx.Request.SetRequestURI("/")
for pb.Next() {
h(fctx)
}
})
}
// go test -v ./... -run=^$ -bench=Benchmark_Router_Next_Default_Immutable -benchmem -count=4
func Benchmark_Router_Next_Default_Immutable(b *testing.B) {
app := New(Config{Immutable: true})
app.Get("/", func(_ Ctx) error {
return nil
})
h := app.Handler()
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(MethodGet)
fctx.Request.SetRequestURI("/")
b.ReportAllocs()
for b.Loop() {
h(fctx)
}
}
// go test -benchmem -run=^$ -bench ^Benchmark_Router_Next_Default_Parallel_Immutable$ github.com/gofiber/fiber/v3 -count=1
func Benchmark_Router_Next_Default_Parallel_Immutable(b *testing.B) {
app := New(Config{Immutable: true})
app.Get("/", func(_ Ctx) error {
return nil
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
fctx := &fasthttp.RequestCtx{}
fctx.Request.Header.SetMethod(MethodGet)
fctx.Request.SetRequestURI("/")
for pb.Next() {
h(fctx)
}
})
}
// go test -v ./... -run=^$ -bench=Benchmark_Route_Match -benchmem -count=4
func Benchmark_Route_Match(b *testing.B) {
var match bool
var params [maxParams]string
parsed := parseRoute("/user/keys/:id")
route := &Route{
use: false,
root: false,
star: false,
routeParser: parsed,
Params: parsed.params,
path: "/user/keys/:id",
Path: "/user/keys/:id",
Method: "DELETE",
}
route.Handlers = append(route.Handlers, func(_ Ctx) error {
return nil
})
for b.Loop() {
match = route.match("/user/keys/1337", "/user/keys/1337", ¶ms)
}
require.True(b, match)
require.Equal(b, []string{"1337"}, params[0:len(parsed.params)])
}
// go test -v ./... -run=^$ -bench=Benchmark_Route_Match_Star -benchmem -count=4
func Benchmark_Route_Match_Star(b *testing.B) {
var match bool
var params [maxParams]string
parsed := parseRoute("/*")
route := &Route{
use: false,
root: false,
star: true,
routeParser: parsed,
Params: parsed.params,
path: "/user/keys/bla",
Path: "/user/keys/bla",
Method: "DELETE",
}
route.Handlers = append(route.Handlers, func(_ Ctx) error {
return nil
})
for b.Loop() {
match = route.match("/user/keys/bla", "/user/keys/bla", ¶ms)
}
require.True(b, match)
require.Equal(b, []string{"user/keys/bla"}, params[0:len(parsed.params)])
}
// go test -v ./... -run=^$ -bench=Benchmark_Route_Match_Root -benchmem -count=4
func Benchmark_Route_Match_Root(b *testing.B) {
var match bool
var params [maxParams]string
parsed := parseRoute("/")
route := &Route{
use: false,
root: true,
star: false,
path: "/",
routeParser: parsed,
Params: parsed.params,
Path: "/",
Method: "DELETE",
}
route.Handlers = append(route.Handlers, func(_ Ctx) error {
return nil
})
for b.Loop() {
match = route.match("/", "/", ¶ms)
}
require.True(b, match)
require.Equal(b, []string{}, params[0:len(parsed.params)])
}
// Benchmark_Router_Handler_CaseSensitive is Benchmark_Router_Handler with
// CaseSensitive enabled on the app config before routes are registered.
// go test -v ./... -run=^$ -bench=Benchmark_Router_Handler_CaseSensitive -benchmem -count=4
func Benchmark_Router_Handler_CaseSensitive(b *testing.B) {
	app := New()
	app.config.CaseSensitive = true
	registerDummyRoutes(app)
	handler := app.Handler()

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod("DELETE")
	fctx.URI().SetPath("/user/keys/1337")

	for b.Loop() {
		handler(fctx)
	}
}
// Benchmark_Router_Handler_Unescape measures routing a percent-encoded path
// with UnescapePath enabled. The path is reset on every iteration.
// go test -v ./... -run=^$ -bench=Benchmark_Router_Handler_Unescape -benchmem -count=4
func Benchmark_Router_Handler_Unescape(b *testing.B) {
	const encodedPath = "/cr%C3%A9er"

	app := New()
	app.config.UnescapePath = true
	registerDummyRoutes(app)
	app.Delete("/créer", func(_ Ctx) error {
		return nil
	})
	handler := app.Handler()

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(MethodDelete)
	fctx.URI().SetPath(encodedPath)

	for b.Loop() {
		fctx.URI().SetPath(encodedPath)
		handler(fctx)
	}
}
// Benchmark_Router_Handler_StrictRouting is Benchmark_Router_Handler with
// StrictRouting enabled.
//
// The previous version set CaseSensitive instead, which made it an exact
// duplicate of Benchmark_Router_Handler_CaseSensitive and left StrictRouting
// unmeasured.
// go test -run=^$ -bench=Benchmark_Router_Handler_StrictRouting -benchmem -count=4
func Benchmark_Router_Handler_StrictRouting(b *testing.B) {
	app := New(Config{StrictRouting: true})
	registerDummyRoutes(app)
	appHandler := app.Handler()

	c := &fasthttp.RequestCtx{}
	c.Request.Header.SetMethod("DELETE")
	c.URI().SetPath("/user/keys/1337")

	for b.Loop() {
		appHandler(c)
	}
}
// Benchmark_Router_GitHub_API runs app.next in parallel for every route in the
// routes fixture.
//
// Each worker owns its fasthttp.RequestCtx and result variables. The previous
// version shared one RequestCtx (and the match/err results) across all
// RunParallel goroutines, which is a data race, and its checks read whatever
// the last racing write left behind. Correctness is now verified
// single-threaded on a fresh context after each parallel run.
// go test -run=^$ -bench=Benchmark_Router_GitHub_API -benchmem -count=16
func Benchmark_Router_GitHub_API(b *testing.B) {
	app := New()
	registerDummyRoutes(app)
	app.startupProcess()

	b.ResetTimer()
	for i := range routesFixture.TestRoutes {
		method := routesFixture.TestRoutes[i].Method
		path := routesFixture.TestRoutes[i].Path

		b.RunParallel(func(pb *testing.PB) {
			c := &fasthttp.RequestCtx{}
			c.Request.Header.SetMethod(method)
			for pb.Next() {
				c.URI().SetPath(path)
				ctx := app.AcquireCtx(c).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
				//nolint:errcheck // Benchmark hot path - error checked in verification
				_, _ = app.next(ctx)
				app.ReleaseCtx(ctx)
			}
		})

		verifyC := &fasthttp.RequestCtx{}
		verifyC.Request.Header.SetMethod(method)
		verifyC.URI().SetPath(path)
		verifyCtx := app.AcquireCtx(verifyC).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
		match, err := app.next(verifyCtx)
		app.ReleaseCtx(verifyCtx)
		require.NoError(b, err)
		require.True(b, match)
	}
}
// testRoute is one HTTP method/path pair used by the routing benchmarks.
// NOTE(review): presumably decoded from the routes fixture JSON
// (.github/testdata/testRoutes.json) — confirm.
type testRoute struct {
	Method string `json:"method"` // HTTP method, e.g. "DELETE"
	Path   string `json:"path"`   // request path to route
}
// routeJSON mirrors the top-level layout of the routes fixture: two arrays of
// method/path pairs keyed "test_routes" and "github_api".
// NOTE(review): presumably the shape of .github/testdata/testRoutes.json — confirm.
type routeJSON struct {
	TestRoutes []testRoute `json:"test_routes"`
	GitHubAPI  []testRoute `json:"github_api"`
}
// newCustomApp returns an App whose contexts are customCtx values wrapping a
// DefaultCtx, so routing goes through the custom-context code paths.
func newCustomApp() *App {
	factory := func(app *App) CustomCtx {
		return &customCtx{DefaultCtx: *NewDefaultCtx(app)}
	}
	return NewWithCustomCtx(factory)
}
// Test_NextCustom_MethodNotAllowed checks that nextCustom reports
// ErrMethodNotAllowed with an Allow header when the path exists only for GET
// (plus the auto HEAD route) and a POST arrives. A "use" route on the same path
// is prepended to the GET stack.
func Test_NextCustom_MethodNotAllowed(t *testing.T) {
	t.Parallel()

	app := newCustomApp()
	app.Get("/foo", func(c Ctx) error { return c.SendStatus(StatusOK) })

	mw := &Route{use: true, path: "/foo", Path: "/foo", routeParser: parseRoute("/foo")}
	getIdx := app.methodInt(MethodGet)
	app.stack[getIdx] = append([]*Route{mw}, app.stack[getIdx]...)
	app.routesRefreshed = true
	app.ensureAutoHeadRoutes()
	app.RebuildTree()

	req := &fasthttp.RequestCtx{}
	req.Request.Header.SetMethod(MethodPost)
	req.Request.SetRequestURI("/foo")

	ctx := app.AcquireCtx(req)
	defer app.ReleaseCtx(ctx)

	ok, err := app.nextCustom(ctx)
	require.False(t, ok)
	require.ErrorIs(t, err, ErrMethodNotAllowed)
	require.Equal(t, "GET, HEAD", string(ctx.Response().Header.Peek(HeaderAllow)))
}
// Test_NextCustom_NotFound checks that nextCustom returns a 404 *Error for a
// path with no routes at all.
func Test_NextCustom_NotFound(t *testing.T) {
	t.Parallel()

	app := newCustomApp()
	app.RebuildTree()

	req := &fasthttp.RequestCtx{}
	req.Request.Header.SetMethod(MethodGet)
	req.Request.SetRequestURI("/not-exist")

	ctx := app.AcquireCtx(req)
	defer app.ReleaseCtx(ctx)

	ok, err := app.nextCustom(ctx)
	require.False(t, ok)

	var fiberErr *Error
	require.ErrorAs(t, err, &fiberErr)
	require.Equal(t, StatusNotFound, fiberErr.Code)
}
// Test_RequestHandler_CustomCtx_NotImplemented checks that an unknown HTTP
// method sent to a custom-context app is answered with 501.
func Test_RequestHandler_CustomCtx_NotImplemented(t *testing.T) {
	t.Parallel()

	handler := newCustomApp().Handler()

	req := &fasthttp.RequestCtx{}
	req.Request.Header.SetMethod("UNKNOWN")
	req.Request.SetRequestURI("/")

	handler(req)

	require.Equal(t, StatusNotImplemented, req.Response.StatusCode())
}
// Test_NextCustom_Matched404 checks that nextCustom still returns a 404 *Error
// when the context is already flagged as matched but no route serves the path.
func Test_NextCustom_Matched404(t *testing.T) {
	t.Parallel()

	app := newCustomApp()
	app.RebuildTree()

	req := &fasthttp.RequestCtx{}
	req.Request.Header.SetMethod(MethodGet)
	req.Request.SetRequestURI("/none")

	ctx := app.AcquireCtx(req)
	ctx.setMatched(true)
	defer app.ReleaseCtx(ctx)

	ok, err := app.nextCustom(ctx)
	require.False(t, ok)

	var fiberErr *Error
	require.ErrorAs(t, err, &fiberErr)
	require.Equal(t, StatusNotFound, fiberErr.Code)
}
// Test_NextCustom_SkipMountAndNoHandlers builds a GET stack with a mount route
// on /skip followed by a handler-less route on /foo, and checks that a request
// for /foo matches the /foo route without error.
func Test_NextCustom_SkipMountAndNoHandlers(t *testing.T) {
	t.Parallel()

	app := newCustomApp()
	getIdx := app.methodInt(MethodGet)
	mounted := &Route{path: "/skip", Path: "/skip", routeParser: parseRoute("/skip"), mount: true}
	noHandlers := &Route{path: "/foo", Path: "/foo", routeParser: parseRoute("/foo")}
	app.stack[getIdx] = []*Route{mounted, noHandlers}
	app.routesRefreshed = true
	app.RebuildTree()

	req := &fasthttp.RequestCtx{}
	req.Request.Header.SetMethod(MethodGet)
	req.Request.SetRequestURI("/foo")

	ctx := app.AcquireCtx(req)
	defer app.ReleaseCtx(ctx)

	ok, err := app.nextCustom(ctx)
	require.True(t, ok)
	require.NoError(t, err)
	require.Equal(t, "/foo", ctx.Route().Path)
}
// Test_AddRoute_MergeHandlers checks that registering the same GET path twice
// yields a single route carrying both handlers.
func Test_AddRoute_MergeHandlers(t *testing.T) {
	t.Parallel()

	app := New()
	noop := func(_ Ctx) error { return nil }
	for range 2 {
		app.Get("/merge", noop)
	}

	getStack := app.stack[app.methodInt(MethodGet)]
	require.Len(t, getStack, 1)
	require.Len(t, getStack[0].Handlers, 2)
}
// Benchmark_App_RebuildTree_Parallel measures RebuildTree from several
// goroutines. RebuildTree mutates the App, so every worker gets its own
// instance.
func Benchmark_App_RebuildTree_Parallel(b *testing.B) {
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		workerApp := New()
		registerDummyRoutes(workerApp)
		for pb.Next() {
			workerApp.routesRefreshed = true
			workerApp.RebuildTree()
		}
	})
}
func Benchmark_App_MethodNotAllowed_Parallel(b *testing.B) {
app := New()
h := func(c Ctx) error {
return c.SendString("Hello World!")
}
app.All("/this/is/a/", h)
app.Get("/this/is/a/dummy/route/oke", h)
appHandler := app.Handler()
b.RunParallel(func(pb *testing.PB) {
// Each worker gets its own RequestCtx to avoid data races
c := &fasthttp.RequestCtx{}
c.Request.Header.SetMethod("DELETE")
c.URI().SetPath("/this/is/a/dummy/route/oke")
for pb.Next() {
appHandler(c)
}
})
// Single-threaded verification on a fresh context to preserve correctness checks
verifyCtx := &fasthttp.RequestCtx{}
verifyCtx.Request.Header.SetMethod("DELETE")
verifyCtx.URI().SetPath("/this/is/a/dummy/route/oke")
appHandler(verifyCtx)
require.Equal(b, 405, verifyCtx.Response.StatusCode())
require.Equal(b, MethodGet+", "+MethodHead, string(verifyCtx.Response.Header.Peek("Allow")))
require.Equal(b, utils.StatusMessage(StatusMethodNotAllowed), string(verifyCtx.Response.Body()))
}
// Benchmark_Router_NotFound_Parallel is the parallel form of
// Benchmark_Router_NotFound.
//
// Each worker gets its own fasthttp.RequestCtx. The handler writes the
// response into the context, so sharing one across goroutines (as before) is a
// data race, and the assertions would read a racy response. The 404 is
// verified afterwards on a fresh context.
func Benchmark_Router_NotFound_Parallel(b *testing.B) {
	b.ReportAllocs()
	app := New()
	app.Use(func(c Ctx) error {
		return c.Next()
	})
	registerDummyRoutes(app)
	appHandler := app.Handler()

	newRequest := func() *fasthttp.RequestCtx {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod("DELETE")
		c.URI().SetPath("/this/route/does/not/exist")
		return c
	}

	b.RunParallel(func(pb *testing.PB) {
		c := newRequest()
		for pb.Next() {
			appHandler(c)
		}
	})

	verifyCtx := newRequest()
	appHandler(verifyCtx)
	require.Equal(b, 404, verifyCtx.Response.StatusCode())
	require.Equal(b, "Not Found", string(verifyCtx.Response.Body()))
}
// Benchmark_Router_Handler_Parallel is the parallel form of
// Benchmark_Router_Handler. Each worker owns its RequestCtx: the handler
// writes into it, so a shared context (as before) is a data race.
func Benchmark_Router_Handler_Parallel(b *testing.B) {
	app := New()
	registerDummyRoutes(app)
	appHandler := app.Handler()

	b.RunParallel(func(pb *testing.PB) {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod("DELETE")
		c.URI().SetPath("/user/keys/1337")
		for pb.Next() {
			appHandler(c)
		}
	})
}
// Benchmark_Router_Handler_Strict_Case_Parallel is the parallel form of
// Benchmark_Router_Handler_Strict_Case. Each worker owns its RequestCtx to
// avoid the data race of a shared context.
func Benchmark_Router_Handler_Strict_Case_Parallel(b *testing.B) {
	app := New(Config{StrictRouting: true, CaseSensitive: true})
	registerDummyRoutes(app)
	appHandler := app.Handler()

	b.RunParallel(func(pb *testing.PB) {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod("DELETE")
		c.URI().SetPath("/user/keys/1337")
		for pb.Next() {
			appHandler(c)
		}
	})
}
// Benchmark_Router_Chain_Parallel is the parallel form of
// Benchmark_Router_Chain. Each worker owns its RequestCtx to avoid the data
// race of a shared context.
func Benchmark_Router_Chain_Parallel(b *testing.B) {
	app := New()
	handler := func(c Ctx) error {
		return c.Next()
	}
	app.Get("/", handler, handler, handler, handler, handler, handler)
	appHandler := app.Handler()

	b.RunParallel(func(pb *testing.PB) {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod(MethodGet)
		c.URI().SetPath("/")
		for pb.Next() {
			appHandler(c)
		}
	})
}
// Benchmark_Router_WithCompression_Parallel is the parallel form of
// Benchmark_Router_WithCompression: GET "/" is registered six times, one
// handler each. Each worker owns its RequestCtx to avoid the data race of a
// shared context.
func Benchmark_Router_WithCompression_Parallel(b *testing.B) {
	app := New()
	handler := func(c Ctx) error {
		return c.Next()
	}
	for range 6 {
		app.Get("/", handler)
	}
	appHandler := app.Handler()

	b.RunParallel(func(pb *testing.PB) {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod(MethodGet)
		c.URI().SetPath("/")
		for pb.Next() {
			appHandler(c)
		}
	})
}
// Benchmark_Startup_Process_Parallel builds, populates and starts up a brand
// new App on every iteration of every worker, so no state is shared.
func Benchmark_Startup_Process_Parallel(b *testing.B) {
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			fresh := New()
			registerDummyRoutes(fresh)
			fresh.startupProcess()
		}
	})
}
func Benchmark_Router_Next_Parallel(b *testing.B) {
app := New()
registerDummyRoutes(app)
app.startupProcess()
b.RunParallel(func(pb *testing.PB) {
// Each worker gets its own request and context to avoid data races.
request := &fasthttp.RequestCtx{}
request.Request.Header.SetMethod("DELETE")
request.URI().SetPath("/user/keys/1337")
c := acquireDefaultCtxForRouterBenchmark(b, app, request)
for pb.Next() {
c.indexRoute = -1
//nolint:errcheck // Benchmark hot path - error checked in verification
_, _ = app.next(c)
}
})
// Single-threaded verification on a fresh context to preserve correctness checks.
verifyRequest := &fasthttp.RequestCtx{}
verifyRequest.Request.Header.SetMethod("DELETE")
verifyRequest.URI().SetPath("/user/keys/1337")
verifyCtx := acquireDefaultCtxForRouterBenchmark(b, app, verifyRequest)
verifyCtx.indexRoute = -1
res, err := app.next(verifyCtx)
require.NoError(b, err)
require.True(b, res)
require.Equal(b, 4, verifyCtx.indexRoute)
}
// Benchmark_Router_Next_Default_Immutable_Parallel measures the handler for a
// single no-op GET "/" route with Immutable enabled, from parallel workers.
// Each worker owns its RequestCtx. The previous version shared one fctx across
// all goroutines, which is a data race.
func Benchmark_Router_Next_Default_Immutable_Parallel(b *testing.B) {
	app := New(Config{Immutable: true})
	app.Get("/", func(_ Ctx) error {
		return nil
	})
	h := app.Handler()

	b.ReportAllocs()
	b.RunParallel(func(pb *testing.PB) {
		fctx := &fasthttp.RequestCtx{}
		fctx.Request.Header.SetMethod(MethodGet)
		fctx.Request.SetRequestURI("/")
		for pb.Next() {
			h(fctx)
		}
	})
}
func Benchmark_Route_Match_Parallel(b *testing.B) {
parsed := parseRoute("/user/keys/:id")
route := &Route{use: false, root: false, star: false, routeParser: parsed, Params: parsed.params, path: "/user/keys/:id", Path: "/user/keys/:id", Method: "DELETE"}
route.Handlers = append(route.Handlers, func(_ Ctx) error {
return nil
})
b.RunParallel(func(pb *testing.PB) {
// Each worker gets its own local variables to avoid data races
var params [maxParams]string
for pb.Next() {
_ = route.match("/user/keys/1337", "/user/keys/1337", ¶ms)
}
})
// Single-threaded verification to preserve correctness checks
var verifyParams [maxParams]string
match := route.match("/user/keys/1337", "/user/keys/1337", &verifyParams)
require.True(b, match)
require.Equal(b, []string{"1337"}, verifyParams[0:len(parsed.params)])
}
func Benchmark_Route_Match_Star_Parallel(b *testing.B) {
var match bool
var params [maxParams]string
parsed := parseRoute("/*")
route := &Route{use: false, root: false, star: true, routeParser: parsed, Params: parsed.params, path: "/user/keys/bla", Path: "/user/keys/bla", Method: "DELETE"}
route.Handlers = append(route.Handlers, func(_ Ctx) error {
return nil
})
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
match = route.match("/user/keys/bla", "/user/keys/bla", ¶ms)
}
})
require.True(b, match)
require.Equal(b, []string{"user/keys/bla"}, params[0:len(parsed.params)])
}
func Benchmark_Route_Match_Root_Parallel(b *testing.B) {
var match bool
var params [maxParams]string
parsed := parseRoute("/")
route := &Route{use: false, root: true, star: false, path: "/", routeParser: parsed, Params: parsed.params, Path: "/", Method: "DELETE"}
route.Handlers = append(route.Handlers, func(_ Ctx) error {
return nil
})
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
match = route.match("/", "/", ¶ms)
}
})
require.True(b, match)
require.Equal(b, []string{}, params[0:len(parsed.params)])
}
// Benchmark_Router_Handler_CaseSensitive_Parallel is the parallel form of
// Benchmark_Router_Handler_CaseSensitive. Each worker owns its RequestCtx to
// avoid the data race of a shared context.
func Benchmark_Router_Handler_CaseSensitive_Parallel(b *testing.B) {
	app := New()
	app.config.CaseSensitive = true
	registerDummyRoutes(app)
	appHandler := app.Handler()

	b.RunParallel(func(pb *testing.PB) {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod("DELETE")
		c.URI().SetPath("/user/keys/1337")
		for pb.Next() {
			appHandler(c)
		}
	})
}
// Benchmark_Router_Handler_Unescape_Parallel is the parallel form of
// Benchmark_Router_Handler_Unescape. Each worker owns its RequestCtx: the loop
// rewrites the URI and the handler writes the response, so a shared context
// (as before) is a data race.
func Benchmark_Router_Handler_Unescape_Parallel(b *testing.B) {
	app := New()
	app.config.UnescapePath = true
	registerDummyRoutes(app)
	app.Delete("/créer", func(_ Ctx) error {
		return nil
	})
	appHandler := app.Handler()

	b.RunParallel(func(pb *testing.PB) {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod(MethodDelete)
		c.URI().SetPath("/cr%C3%A9er")
		for pb.Next() {
			c.URI().SetPath("/cr%C3%A9er")
			appHandler(c)
		}
	})
}
// Benchmark_Router_Handler_StrictRouting_Parallel is the parallel form of
// Benchmark_Router_Handler_StrictRouting.
//
// This fixes two problems in the previous version. It set CaseSensitive
// instead of StrictRouting, duplicating the CaseSensitive benchmark. It also
// shared one RequestCtx across workers, which is a data race.
func Benchmark_Router_Handler_StrictRouting_Parallel(b *testing.B) {
	app := New(Config{StrictRouting: true})
	registerDummyRoutes(app)
	appHandler := app.Handler()

	b.RunParallel(func(pb *testing.PB) {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod("DELETE")
		c.URI().SetPath("/user/keys/1337")
		for pb.Next() {
			appHandler(c)
		}
	})
}
// Benchmark_Router_GitHub_API_Parallel runs app.next in parallel for every
// route in the routes fixture. Workers own their RequestCtx, and each route's
// result is verified afterwards on a fresh context.
func Benchmark_Router_GitHub_API_Parallel(b *testing.B) {
	app := New()
	registerDummyRoutes(app)
	app.startupProcess()

	b.ResetTimer()
	for _, tr := range routesFixture.TestRoutes {
		b.RunParallel(func(pb *testing.PB) {
			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.Header.SetMethod(tr.Method)
			for pb.Next() {
				reqCtx.URI().SetPath(tr.Path)
				ctx := acquireDefaultCtxForRouterBenchmark(b, app, reqCtx)
				//nolint:errcheck // Benchmark hot path - error checked in verification
				_, _ = app.next(ctx)
				app.ReleaseCtx(ctx)
			}
		})

		checkReq := &fasthttp.RequestCtx{}
		checkReq.Request.Header.SetMethod(tr.Method)
		checkReq.URI().SetPath(tr.Path)
		checkCtx := acquireDefaultCtxForRouterBenchmark(b, app, checkReq)
		matched, err := app.next(checkCtx)
		app.ReleaseCtx(checkCtx)
		require.NoError(b, err)
		require.True(b, matched)
	}
}
================================================
FILE: services.go
================================================
package fiber
import (
"context"
"errors"
"fmt"
"io"
utilsstrings "github.com/gofiber/utils/v2/strings"
)
// Service is a dependency whose lifecycle is managed by the App. Configured
// services are started in order, can report their current state, and are
// terminated during shutdown.
type Service interface {
	// String returns a human-readable name for the service. It is printed in
	// the startup message.
	String() string

	// Start starts the service and returns an error if it fails.
	Start(ctx context.Context) error

	// State reports the current state of the service.
	State(ctx context.Context) (string, error)

	// Terminate stops the service and returns an error if it fails.
	Terminate(ctx context.Context) error
}
// hasConfiguredServices reports whether at least one service was supplied in
// the application's configuration.
func (app *App) hasConfiguredServices() bool {
	count := len(app.configured.Services)
	return count != 0
}
// validateConfiguredServices returns an error naming the index of the first
// nil entry in the configured services, or nil if every entry is set.
func (app *App) validateConfiguredServices() error {
	services := app.configured.Services
	return validateServicesSlice(services)
}
// validateServicesSlice returns an error for the first nil entry in services,
// identifying it by index. An empty or nil slice is valid.
func validateServicesSlice(services []Service) error {
	for i := range services {
		if services[i] != nil {
			continue
		}
		return fmt.Errorf("fiber: service at index %d is nil", i)
	}
	return nil
}
// initServices starts every configured service using the context from
// servicesStartupCtx. It does nothing when no services are configured and
// panics with the start error if any service fails to start.
func (app *App) initServices() {
	if app.hasConfiguredServices() {
		if err := app.startServices(app.servicesStartupCtx()); err != nil {
			panic(err)
		}
	}
}
// servicesStartupCtx returns the context used to start services. It comes from
// ServicesStartupContextProvider when configured, and otherwise falls back to
// context.Background().
func (app *App) servicesStartupCtx() context.Context {
	provider := app.configured.ServicesStartupContextProvider
	if provider == nil {
		return context.Background()
	}
	return provider()
}
// servicesShutdownCtx returns the context used to terminate services. It comes
// from ServicesShutdownContextProvider when configured, and otherwise falls
// back to context.Background().
func (app *App) servicesShutdownCtx() context.Context {
	provider := app.configured.ServicesShutdownContextProvider
	if provider == nil {
		return context.Background()
	}
	return provider()
}
// startServices Handles the start process of services for the current application.
// It starts every configured service in configuration order and records each
// one that starts successfully in app.state. Ordinary start failures are
// collected, so one failing service does not stop the remaining ones from being
// attempted, and are returned together via errors.Join.
//
// Iteration stops early, without stopping services that already started, when:
//   - a configured service is nil,
//   - ctx is already done before a service is started, or
//   - a service's Start returns context.Canceled / context.DeadlineExceeded.
//
// In the early-stop case the returned error also carries any start failures
// collected before the stop, so they are no longer silently dropped. When
// nothing was collected, the early-stop error is returned unchanged.
func (app *App) startServices(ctx context.Context) error {
	if !app.hasConfiguredServices() {
		return nil
	}

	var errs []error
	// abort combines err with the failures collected so far.
	abort := func(err error) error {
		if len(errs) == 0 {
			return err
		}
		return errors.Join(append(errs, err)...)
	}

	for idx, srv := range app.configured.Services {
		if srv == nil {
			return abort(fmt.Errorf("fiber: service at index %d is nil", idx))
		}

		if err := ctx.Err(); err != nil {
			// Context is canceled, return an error the soonest possible, so that
			// the user can see the context cancellation error and act on it.
			return abort(fmt.Errorf("context canceled while starting service %s: %w", srv.String(), err))
		}

		err := srv.Start(ctx)
		if err == nil {
			// mark the service as started
			app.state.setService(srv)
			continue
		}

		wrapped := fmt.Errorf("service %s start: %w", srv.String(), err)
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			return abort(wrapped)
		}
		errs = append(errs, wrapped)
	}

	return errors.Join(errs...)
}
// shutdownServices terminates the services recorded as started in app.state.
// A service that terminates successfully is removed from the state. A failing
// one stays and its error is collected. Once ctx is done, the remaining
// services are not asked to terminate, but each gets an error wrapping ctx.Err().
// All collected errors are returned via errors.Join. A nil service aborts
// immediately.
// NOTE(review): services are visited in whatever order app.state.Services()
// yields; the previous comment claimed reverse start order, which is not
// visible here — confirm against State.Services.
func (app *App) shutdownServices(ctx context.Context) error {
	if app.state.ServicesLen() == 0 {
		return nil
	}

	var failures []error
	for key, srv := range app.state.Services() {
		if srv == nil {
			return fmt.Errorf("fiber: service %q is nil", key)
		}

		// A done context short-circuits Terminate but is still reported, so
		// this stays a best-effort pass over every service.
		termErr := ctx.Err()
		if termErr == nil {
			termErr = srv.Terminate(ctx)
		}
		if termErr != nil {
			failures = append(failures, fmt.Errorf("service %s terminate: %w", srv.String(), termErr))
			continue
		}

		app.state.deleteService(srv)
	}
	return errors.Join(failures...)
}
// logServices writes a summary line with the number of started services to
// out, then one line per started service showing its upper-cased State. A
// failing State call is shown as errString in red. The color scheme defaults to
// DefaultColors when colors is nil. Returns an error if a started service is
// nil; does nothing when no services are configured.
func (app *App) logServices(ctx context.Context, out io.Writer, colors *Colors) error {
	if !app.hasConfiguredServices() {
		return nil
	}

	palette := &DefaultColors
	if colors != nil {
		palette = colors
	}

	fmt.Fprintf(out,
		"%sINFO%s Services: \t%s%d%s\n",
		palette.Green, palette.Reset, palette.Blue, app.state.ServicesLen(), palette.Reset)

	for key, srv := range app.state.Services() {
		if srv == nil {
			return fmt.Errorf("fiber: service %q is nil", key)
		}

		label, labelColor := errString, palette.Red
		if current, err := srv.State(ctx); err == nil {
			label, labelColor = current, palette.Blue
		}

		fmt.Fprintf(out, "%sINFO%s 🧩 %s[ %s ] %s%s\n", palette.Green, palette.Reset, labelColor, utilsstrings.ToUpper(label), srv.String(), palette.Reset)
	}
	return nil
}
================================================
FILE: services_test.go
================================================
package fiber
import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v3/log"
"github.com/stretchr/testify/require"
)
// Base error messages used by the mock services in these tests; individual
// tests append a numeric suffix to tell services apart.
const (
	startErrorMessage     = "start error"
	terminateErrorMessage = "terminate error"
)
// mockService implements Service interface for testing.
// Its behavior is driven entirely by the fields below.
type mockService struct {
	startError     error         // returned by Start when non-nil
	terminateError error         // returned by Terminate when non-nil
	stateError     error         // returned by State (with state "error") when non-nil
	name           string        // value reported by String
	started        bool          // set by a successful Start, cleared by a successful Terminate
	terminated     bool          // set by a successful Terminate
	startDelay     time.Duration // how long Start waits (aborting if ctx is done) before finishing
	terminateDelay time.Duration // how long Terminate waits (aborting if ctx is done) before finishing
}
// Start simulates starting the service. It fails with a wrapped ctx.Err() if
// ctx is done up front or before startDelay elapses. Otherwise it returns
// startError, and marks the mock as started only when startError is nil.
func (m *mockService) Start(ctx context.Context) error {
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("context canceled: %w", err)
	}

	if m.startDelay > 0 {
		wait := time.NewTimer(m.startDelay)
		defer wait.Stop()
		select {
		case <-ctx.Done():
			return fmt.Errorf("context canceled: %w", ctx.Err())
		case <-wait.C:
		}
	}

	m.started = m.startError == nil
	return m.startError
}
// String reports the mock's configured name.
func (m *mockService) String() string {
	name := m.name
	return name
}
// State reports the mock's state with this precedence: a wrapped ctx.Err() if
// ctx is done, ("error", stateError) if stateError is set, then "running",
// "stopped" or "unknown" based on the started/terminated flags.
func (m *mockService) State(ctx context.Context) (string, error) {
	if err := ctx.Err(); err != nil {
		return "", fmt.Errorf("context canceled: %w", err)
	}

	switch {
	case m.stateError != nil:
		return "error", m.stateError
	case m.started:
		return "running", nil
	case m.terminated:
		return "stopped", nil
	default:
		return "unknown", nil
	}
}
// Terminate simulates stopping the service. It fails with a wrapped ctx.Err()
// if ctx is done up front or before terminateDelay elapses. Otherwise it
// returns terminateError. On success the mock is marked terminated and no
// longer started; on failure terminated is cleared and started is left as is.
func (m *mockService) Terminate(ctx context.Context) error {
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("context canceled: %w", err)
	}

	if m.terminateDelay > 0 {
		wait := time.NewTimer(m.terminateDelay)
		defer wait.Stop()
		select {
		case <-ctx.Done():
			return fmt.Errorf("context canceled: %w", ctx.Err())
		case <-wait.C:
		}
	}

	m.terminated = m.terminateError == nil
	if m.terminated {
		m.started = false
	}
	return m.terminateError
}
// Test_HasConfiguredServices checks hasConfiguredServices with and without
// configured services.
func Test_HasConfiguredServices(t *testing.T) {
	cases := []struct {
		name     string
		app      *App
		expected bool
	}{
		{
			name:     "no-services",
			app:      &App{configured: Config{}},
			expected: false,
		},
		{
			name:     "has-services",
			app:      &App{configured: Config{Services: []Service{&mockService{name: "test-dep"}}}},
			expected: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.Equal(t, tc.expected, tc.app.hasConfiguredServices())
		})
	}
}
func Test_InitServices(t *testing.T) {
t.Run("no-services", func(t *testing.T) {
app := &App{configured: Config{}}
require.NotPanics(t, app.initServices)
})
t.Run("start/success", func(t *testing.T) {
// Initialize the app using the struct and defining the state and hooks manually,
// because we are not checking the shutdown hooks in this test.
app := &App{
configured: Config{
Services: []Service{
&mockService{name: "dep1"},
&mockService{name: "dep2"},
},
},
state: newState(),
}
app.hooks = newHooks(app)
require.NotPanics(t, app.initServices)
})
t.Run("start/error", func(t *testing.T) {
// Initialize the app using the struct and defining the state and hooks manually,
// because we are not checking the shutdown hooks in this test.
app := &App{
configured: Config{
Services: []Service{
&mockService{name: "dep1", startError: errors.New(startErrorMessage + " 1")},
&mockService{name: "dep2", startError: errors.New(startErrorMessage + " 2")},
&mockService{name: "dep3"},
},
},
state: newState(),
}
app.hooks = newHooks(app)
require.Panics(t, app.initServices)
})
t.Run("shutdown-hooks/success", func(t *testing.T) {
// Initialize the app using the New function to verify that the shutdown hooks are registered
// and the app mutex is not causing a deadlock.
app := New(Config{
Services: []Service{&mockService{name: "dep1"}},
})
require.NotPanics(t, app.initServices)
type stringsLogger struct {
strings.Builder
}
var buf stringsLogger
log.SetOutput(&buf)
app.Hooks().executeOnPostShutdownHooks(nil)
require.NotContains(t, buf.String(), "failed to call post shutdown hook:")
})
t.Run("shutdown-hooks/error", func(t *testing.T) {
// Initialize the app using the New function to verify that the shutdown hooks are registered
// and the app mutex is not causing a deadlock.
app := New(Config{
Services: []Service{
&mockService{name: "dep1"},
&mockService{name: "dep2", terminateError: errors.New(terminateErrorMessage + " 2")},
},
})
require.NotPanics(t, app.initServices)
type stringsLogger struct {
strings.Builder
}
var buf stringsLogger
log.SetOutput(&buf)
app.Hooks().executeOnPostShutdownHooks(nil)
require.Contains(t, buf.String(), "failed to shutdown services: service dep2 terminate: terminate error 2")
})
}
func Test_StartServices(t *testing.T) {
t.Run("no-services", func(t *testing.T) {
app := &App{
configured: Config{
Services: []Service{},
},
state: newState(),
}
err := app.startServices(context.Background())
require.NoError(t, err)
require.Zero(t, app.state.ServicesLen())
})
t.Run("successful-start", func(t *testing.T) {
app := &App{
configured: Config{
Services: []Service{
&mockService{name: "dep1"},
&mockService{name: "dep2"},
},
},
state: newState(),
}
err := app.startServices(context.Background())
require.NoError(t, err)
require.Equal(t, 2, app.state.ServicesLen())
})
t.Run("failed-start", func(t *testing.T) {
app := &App{
configured: Config{
Services: []Service{
&mockService{name: "dep1", startError: errors.New(startErrorMessage + " 1")},
&mockService{name: "dep2", startError: errors.New(startErrorMessage + " 2")},
&mockService{name: "dep3"},
},
},
state: newState(),
}
err := app.startServices(context.Background())
require.Error(t, err)
require.Contains(t, err.Error(), startErrorMessage+" 1")
require.Contains(t, err.Error(), startErrorMessage+" 2")
require.Equal(t, 1, app.state.ServicesLen())
})
t.Run("context", func(t *testing.T) {
t.Run("already-canceled", func(t *testing.T) {
app := &App{
configured: Config{
Services: []Service{
&mockService{name: "dep1"},
},
},
state: newState(),
}
// Create a context that is already canceled
ctx, cancel := context.WithCancel(context.Background())
cancel()
err := app.startServices(ctx)
require.ErrorIs(t, err, context.Canceled)
require.Zero(t, app.state.ServicesLen())
})
t.Run("cancellation", func(t *testing.T) {
// Create a service that takes some time to start
slowDep := &mockService{name: "slow-dep", startDelay: 200 * time.Millisecond}
app := &App{
configured: Config{
Services: []Service{slowDep},
},
state: newState(),
}
// Create a context that will be canceled immediately
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
// Start services with a delay that is longer than the timeout
err := app.startServices(ctx)
require.ErrorIs(t, err, context.DeadlineExceeded)
require.Zero(t, app.state.ServicesLen())
})
})
}
// Test_ShutdownServices checks that shutdownServices terminates the services
// recorded in the app State, removes the ones that terminated cleanly, and
// reports termination and context errors.
func Test_ShutdownServices(t *testing.T) {
	t.Run("no-services", func(t *testing.T) {
		app := &App{
			configured: Config{
				Services: []Service{},
			},
			state: newState(),
		}
		err := app.shutdownServices(context.Background())
		require.NoError(t, err)
		require.Zero(t, app.state.ServicesLen())
	})
	t.Run("successful-shutdown", func(t *testing.T) {
		srv1 := &mockService{name: "dep1"}
		srv2 := &mockService{name: "dep2"}
		// State as it looks after both services have been started.
		expectedState := newState()
		expectedState.setService(srv1)
		expectedState.setService(srv2)
		app := &App{
			configured: Config{
				Services: []Service{srv1, srv2},
			},
			state: expectedState,
		}
		err := app.shutdownServices(context.Background())
		require.NoError(t, err)
		// Both services terminated cleanly, so none remain in the state.
		require.Zero(t, app.state.ServicesLen())
	})
	t.Run("failed-shutdown", func(t *testing.T) {
		srv1 := &mockService{name: "dep1", terminateError: errors.New(terminateErrorMessage + " 1")}
		srv2 := &mockService{name: "dep2", terminateError: errors.New(terminateErrorMessage + " 2")}
		srv3 := &mockService{name: "dep3"}
		// State as it looks after all three services have been started.
		expectedState := newState()
		expectedState.setService(srv1)
		expectedState.setService(srv2)
		expectedState.setService(srv3)
		app := &App{
			configured: Config{
				Services: []Service{srv1, srv2, srv3},
			},
			state: expectedState,
		}
		err := app.shutdownServices(context.Background())
		// Both termination failures must be reported in the returned error.
		require.ErrorContains(t, err, terminateErrorMessage+" 1")
		require.ErrorContains(t, err, terminateErrorMessage+" 2")
		// Only dep3 terminated cleanly; the two failing services stay in the state.
		require.Equal(t, 2, app.state.ServicesLen())
	})
	t.Run("context", func(t *testing.T) {
		t.Run("already-canceled", func(t *testing.T) {
			srv1 := &mockService{name: "dep1"}
			// State as it looks after the single service has been started.
			expectedState := newState()
			expectedState.setService(srv1)
			app := &App{
				configured: Config{
					Services: []Service{srv1},
				},
				state: expectedState,
			}
			// Cancel the context before shutdown begins.
			ctx, cancel := context.WithCancel(context.Background())
			cancel()
			err := app.shutdownServices(ctx)
			require.ErrorIs(t, err, context.Canceled)
			require.ErrorContains(t, err, "service dep1 terminate")
			// The service could not be terminated, so it stays in the state.
			require.Equal(t, 1, app.state.ServicesLen())
		})
		t.Run("cancellation", func(t *testing.T) {
			// A service whose termination takes longer than the shutdown deadline.
			slowDep := &mockService{name: "slow-dep", terminateDelay: 200 * time.Millisecond}
			// State as it looks after the single service has been started.
			expectedState := newState()
			expectedState.setService(slowDep)
			app := &App{
				configured: Config{
					Services: []Service{slowDep},
				},
				state: expectedState,
			}
			// The 100ms deadline expires before the 200ms terminate delay elapses.
			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
			defer cancel()
			err := app.shutdownServices(ctx)
			require.ErrorIs(t, err, context.DeadlineExceeded)
			require.Equal(t, 1, app.state.ServicesLen())
		})
	})
}
// Test_LogServices verifies that logServices writes one colored line per
// service, labeled RUNNING or ERROR depending on whether State succeeds.
func Test_LogServices(t *testing.T) {
	t.Parallel()

	// One service reports its state normally; the other fails when asked.
	healthy := &mockService{name: "running", started: true}
	broken := &mockService{name: "error", stateError: errors.New("state error")}

	st := newState()
	st.setService(healthy)
	st.setService(broken)

	app := &App{
		configured: Config{Services: []Service{healthy, broken}},
		state:      st,
	}

	colors := Colors{
		Green: "\033[32m",
		Reset: "\033[0m",
		Blue:  "\033[34m",
		Red:   "\033[31m",
	}

	var out bytes.Buffer
	require.NoError(t, app.logServices(context.Background(), &out, &colors))
	logged := out.String()

	for _, srv := range app.state.Services() {
		label, color := "RUNNING", colors.Blue
		if _, stateErr := srv.State(context.Background()); stateErr != nil {
			label, color = "ERROR", colors.Red
		}
		line := fmt.Sprintf("%sINFO%s 🧩 %s[ %s ] %s%s\n", colors.Green, colors.Reset, color, strings.ToUpper(label), srv.String(), colors.Reset)
		require.Contains(t, logged, line)
	}
}
func Test_NewConfiguredServicesNil(t *testing.T) {
t.Parallel()
require.PanicsWithError(t, "fiber: service at index 0 is nil", func() {
New(Config{
Services: []Service{nil},
})
})
}
// Test_ServiceContextProviders checks the contexts used to start and stop
// services, both with and without configured providers.
func Test_ServiceContextProviders(t *testing.T) {
	t.Run("no-provider", func(t *testing.T) {
		// Without providers, both contexts are context.Background().
		app := &App{configured: Config{}, state: newState()}
		require.Equal(t, context.Background(), app.servicesStartupCtx())
		require.Equal(t, context.Background(), app.servicesShutdownCtx())
	})

	t.Run("with-provider", func(t *testing.T) {
		want := context.TODO()
		provide := func() context.Context { return want }
		app := &App{
			configured: Config{
				ServicesStartupContextProvider:  provide,
				ServicesShutdownContextProvider: provide,
			},
			state: newState(),
		}
		require.Equal(t, want, app.servicesStartupCtx())
		require.Equal(t, want, app.servicesShutdownCtx())
	})
}
// Benchmark_StartServices measures app creation plus startServices for
// several service configurations.
func Benchmark_StartServices(b *testing.B) {
	run := func(b *testing.B, services []Service) {
		b.Helper()
		for b.Loop() {
			app := New(Config{Services: services})
			if err := app.startServices(context.Background()); err != nil {
				b.Fatal("Expected no error but got", err)
			}
		}
	}

	cases := []struct {
		name     string
		services []Service
	}{
		{name: "no-services", services: []Service{}},
		{name: "single-service", services: []Service{
			&mockService{name: "dep1"},
		}},
		{name: "multiple-services", services: []Service{
			&mockService{name: "dep1"},
			&mockService{name: "dep2"},
			&mockService{name: "dep3"},
		}},
		{name: "multiple-services-with-delays", services: []Service{
			&mockService{name: "dep1", startDelay: 1 * time.Millisecond},
			&mockService{name: "dep2", startDelay: 2 * time.Millisecond},
			&mockService{name: "dep3", startDelay: 3 * time.Millisecond},
		}},
	}
	for _, tc := range cases {
		b.Run(tc.name, func(b *testing.B) {
			run(b, tc.services)
		})
	}
}
// Benchmark_ShutdownServices measures app creation plus shutdownServices for
// several service configurations.
func Benchmark_ShutdownServices(b *testing.B) {
	run := func(b *testing.B, services []Service) {
		b.Helper()
		for b.Loop() {
			app := New(Config{Services: services})
			if err := app.shutdownServices(context.Background()); err != nil {
				b.Fatal("Expected no error but got", err)
			}
		}
	}

	cases := []struct {
		name     string
		services []Service
	}{
		{name: "no-services", services: []Service{}},
		{name: "single-service", services: []Service{
			&mockService{name: "dep1"},
		}},
		{name: "multiple-services", services: []Service{
			&mockService{name: "dep1"},
			&mockService{name: "dep2"},
			&mockService{name: "dep3"},
		}},
		{name: "multiple-services-with-delays", services: []Service{
			&mockService{name: "dep1", terminateDelay: 1 * time.Millisecond},
			&mockService{name: "dep2", terminateDelay: 2 * time.Millisecond},
			&mockService{name: "dep3", terminateDelay: 3 * time.Millisecond},
		}},
	}
	for _, tc := range cases {
		b.Run(tc.name, func(b *testing.B) {
			run(b, tc.services)
		})
	}
}
// Benchmark_StartServices_withContextCancellation measures startServices when
// the startup context has a deadline, both when it expires early and when the
// services finish in time.
func Benchmark_StartServices_withContextCancellation(b *testing.B) {
	// expectCancellation starts the services under a deadline shorter than
	// their start delays, so each iteration is expected to fail.
	expectCancellation := func(b *testing.B, services []Service, timeout time.Duration) {
		b.Helper()
		for b.Loop() {
			app := New(Config{Services: services})
			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			err := app.startServices(ctx)
			// Only sub-second timeouts are required to surface an error.
			if timeout < time.Second && err == nil {
				b.Fatal("Expected error due to context cancellation but got none")
			}
			cancel()
		}
	}

	b.Run("single-service/immediate-cancellation", func(b *testing.B) {
		expectCancellation(b, []Service{
			&mockService{name: "dep1", startDelay: 100 * time.Millisecond},
		}, 10*time.Millisecond)
	})
	b.Run("multiple-services/immediate-cancellation", func(b *testing.B) {
		expectCancellation(b, []Service{
			&mockService{name: "dep1", startDelay: 100 * time.Millisecond},
			&mockService{name: "dep2", startDelay: 200 * time.Millisecond},
			&mockService{name: "dep3", startDelay: 300 * time.Millisecond},
		}, 10*time.Millisecond)
	})
	b.Run("multiple-services/successful-completion", func(b *testing.B) {
		const timeout = 500 * time.Millisecond
		for b.Loop() {
			app := New(Config{Services: []Service{
				&mockService{name: "dep1", startDelay: 10 * time.Millisecond},
				&mockService{name: "dep2", startDelay: 20 * time.Millisecond},
				&mockService{name: "dep3", startDelay: 30 * time.Millisecond},
			}})
			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			if err := app.startServices(ctx); err != nil {
				b.Fatal("Expected no error but got", err)
			}
			cancel()
		}
	})
}
// Benchmark_ShutdownServices_withContextCancellation starts services without a
// deadline and then measures shutdownServices under a deadline, both when it
// expires early and when termination finishes in time.
func Benchmark_ShutdownServices_withContextCancellation(b *testing.B) {
	// expectCancellation shuts the services down under a deadline shorter than
	// their terminate delays, so each iteration is expected to fail.
	expectCancellation := func(b *testing.B, services []Service, timeout time.Duration) {
		b.Helper()
		for b.Loop() {
			app := New(Config{Services: services})
			if err := app.startServices(context.Background()); err != nil {
				b.Fatal("Expected no error during startup but got", err)
			}
			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			err := app.shutdownServices(ctx)
			// Only sub-second timeouts are required to surface an error.
			if timeout < time.Second && err == nil {
				b.Fatal("Expected error due to context cancellation but got none")
			}
			cancel()
		}
	}

	b.Run("single-service/immediate-cancellation", func(b *testing.B) {
		expectCancellation(b, []Service{
			&mockService{name: "dep1", terminateDelay: 100 * time.Millisecond},
		}, 10*time.Millisecond)
	})
	b.Run("multiple-services/immediate-cancellation", func(b *testing.B) {
		expectCancellation(b, []Service{
			&mockService{name: "dep1", terminateDelay: 100 * time.Millisecond},
			&mockService{name: "dep2", terminateDelay: 200 * time.Millisecond},
			&mockService{name: "dep3", terminateDelay: 300 * time.Millisecond},
		}, 10*time.Millisecond)
	})
	b.Run("multiple-services/successful-completion", func(b *testing.B) {
		const timeout = 500 * time.Millisecond
		for b.Loop() {
			app := New(Config{Services: []Service{
				&mockService{name: "dep1", terminateDelay: 10 * time.Millisecond},
				&mockService{name: "dep2", terminateDelay: 20 * time.Millisecond},
				&mockService{name: "dep3", terminateDelay: 30 * time.Millisecond},
			}})
			if err := app.startServices(context.Background()); err != nil {
				b.Fatal("Expected no error but got", err)
			}
			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			if err := app.shutdownServices(ctx); err != nil {
				b.Fatal("Expected no error but got", err)
			}
			cancel()
		}
	})
}
// Benchmark_ServicesMemory reports allocations for a full start/shutdown cycle
// of the configured services on a freshly created app.
func Benchmark_ServicesMemory(b *testing.B) {
	benchmarkFn := func(b *testing.B, services []Service) {
		b.Helper()
		b.ReportAllocs()
		ctx := context.Background()
		for b.Loop() {
			app := New(Config{
				Services: services,
			})
			// Fail on the first error instead of overwriting it in the next
			// iteration; otherwise only the last iteration would be checked.
			if err := app.startServices(ctx); err != nil {
				b.Fatal("Expected no error during startup but got", err)
			}
			if err := app.shutdownServices(ctx); err != nil {
				b.Fatal("Expected no error during shutdown but got", err)
			}
		}
	}
	b.Run("no-services", func(b *testing.B) {
		benchmarkFn(b, []Service{})
	})
	b.Run("single-service", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1"},
		})
	})
	b.Run("multiple-services", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1"},
			&mockService{name: "dep2"},
			&mockService{name: "dep3"},
		})
	})
}
================================================
FILE: state.go
================================================
package fiber
import (
"encoding/hex"
"strings"
"sync"
"github.com/google/uuid"
)
const servicesStatePrefix = "gofiber-services-"

// servicesStatePrefixHash is prepended to the State key of every service.
// It is the hex encoding of servicesStatePrefix followed by a random UUID,
// computed once when the package is initialized, so that service keys are
// unlikely to clash with keys chosen by users.
var servicesStatePrefixHash = hex.EncodeToString([]byte(servicesStatePrefix + uuid.New().String()))
// State is a concurrency-safe key-value store attached to a Fiber app and
// used as global storage for the app's dependencies and services.
// Entries live in a sync.Map; the public setters only use string keys.
type State struct {
	// dependencies holds every stored entry, services included.
	dependencies sync.Map
	// servicePrefix is prepended to the keys under which services are stored.
	servicePrefix string
}
// newState creates an empty State whose service keys are prefixed with
// servicesStatePrefixHash (a hex-encoded, per-process random string).
func newState() *State {
	// The zero value of sync.Map is ready to use; only the prefix needs setting.
	return &State{
		servicePrefix: servicesStatePrefixHash,
	}
}
// Set stores value under key, replacing whatever was stored there before.
func (s *State) Set(key string, value any) {
	s.dependencies.Store(key, value)
}
// Get returns the value stored under key and whether the key was present.
func (s *State) Get(key string) (any, bool) {
	value, found := s.dependencies.Load(key)
	return value, found
}
// MustGet is like Get but panics when key is absent.
func (s *State) MustGet(key string) any {
	dep, found := s.Get(key)
	if !found {
		panic("state: dependency not found!")
	}
	return dep
}
// Has reports whether key is present in the State, regardless of its value.
func (s *State) Has(key string) bool {
	_, found := s.dependencies.Load(key)
	return found
}
// Delete removes key and its value from the State.
// Deleting a key that is not present does nothing.
func (s *State) Delete(key string) {
	s.dependencies.Delete(key)
}
// Reset empties the State, dropping every entry (services included).
func (s *State) Reset() {
	s.dependencies.Clear()
}
// Keys returns every string key currently in the State, in no particular order.
// Entries stored under non-string keys are skipped.
func (s *State) Keys() []string {
	// The initial capacity is only a hint to avoid early reallocations.
	out := make([]string, 0, 8)
	s.dependencies.Range(func(k, _ any) bool {
		if name, isString := k.(string); isString {
			out = append(out, name)
		}
		return true // visit every entry
	})
	return out
}
// Len counts every entry in the State, whatever the type of its key.
func (s *State) Len() int {
	var n int
	s.dependencies.Range(func(any, any) bool {
		n++
		return true
	})
	return n
}
// GetState looks up key and asserts the stored value to T.
// The boolean is true only when the key exists and its value is a T;
// otherwise the zero value of T is returned.
func GetState[T any](s *State, key string) (T, bool) {
	var zero T
	raw, found := s.Get(key)
	if !found {
		return zero, false
	}
	typed, isT := raw.(T)
	return typed, isT
}
// MustGetState is like GetState but panics when the key is absent or its
// value is not a T. The same panic message is used in both cases.
func MustGetState[T any](s *State, key string) T {
	if dep, ok := GetState[T](s, key); ok {
		return dep
	}
	panic("state: dependency not found!")
}
// GetStateWithDefault returns the value stored under key as a T.
// It returns defaultVal when the key is absent OR when the stored value is
// not of type T.
func GetStateWithDefault[T any](s *State, key string, defaultVal T) T {
	if dep, ok := GetState[T](s, key); ok {
		return dep
	}
	return defaultVal
}
// GetString returns the string stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetString(key string) (string, bool) {
	str, ok := GetState[string](s, key)
	return str, ok
}
// GetInt returns the int stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetInt(key string) (int, bool) {
	n, ok := GetState[int](s, key)
	return n, ok
}
// GetBool returns the bool stored under key. ok is false when
// the key is absent or holds a value of another type.
func (s *State) GetBool(key string) (value, ok bool) { //nolint:nonamedreturns // Better idea to use named returns here
	value, ok = GetState[bool](s, key)
	return value, ok
}
// GetFloat64 returns the float64 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetFloat64(key string) (float64, bool) {
	f, ok := GetState[float64](s, key)
	return f, ok
}
// GetUint returns the uint stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetUint(key string) (uint, bool) {
	u, ok := GetState[uint](s, key)
	return u, ok
}
// GetInt8 returns the int8 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetInt8(key string) (int8, bool) {
	n, ok := GetState[int8](s, key)
	return n, ok
}
// GetInt16 returns the int16 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetInt16(key string) (int16, bool) {
	n, ok := GetState[int16](s, key)
	return n, ok
}
// GetInt32 returns the int32 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetInt32(key string) (int32, bool) {
	n, ok := GetState[int32](s, key)
	return n, ok
}
// GetInt64 returns the int64 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetInt64(key string) (int64, bool) {
	n, ok := GetState[int64](s, key)
	return n, ok
}
// GetUint8 returns the uint8 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetUint8(key string) (uint8, bool) {
	u, ok := GetState[uint8](s, key)
	return u, ok
}
// GetUint16 returns the uint16 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetUint16(key string) (uint16, bool) {
	u, ok := GetState[uint16](s, key)
	return u, ok
}
// GetUint32 returns the uint32 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetUint32(key string) (uint32, bool) {
	u, ok := GetState[uint32](s, key)
	return u, ok
}
// GetUint64 returns the uint64 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetUint64(key string) (uint64, bool) {
	u, ok := GetState[uint64](s, key)
	return u, ok
}
// GetUintptr returns the uintptr stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetUintptr(key string) (uintptr, bool) {
	p, ok := GetState[uintptr](s, key)
	return p, ok
}
// GetFloat32 returns the float32 stored under key. The boolean is false when
// the key is absent or holds a value of another type.
func (s *State) GetFloat32(key string) (float32, bool) {
	f, ok := GetState[float32](s, key)
	return f, ok
}
// GetComplex64 returns the complex64 stored under key. The boolean is false
// when the key is absent or holds a value of another type.
func (s *State) GetComplex64(key string) (complex64, bool) {
	c, ok := GetState[complex64](s, key)
	return c, ok
}
// GetComplex128 returns the complex128 stored under key. The boolean is false
// when the key is absent or holds a value of another type.
func (s *State) GetComplex128(key string) (complex128, bool) {
	c, ok := GetState[complex128](s, key)
	return c, ok
}
// serviceKey builds the State key for the service whose String() value is key.
// The result is the State's servicePrefix followed by the hex encoding of key.
// Hex encoding is reversible (not a hash), so distinct service names always
// produce distinct keys. The random prefix keeps service keys apart from
// ordinary user keys.
func (s *State) serviceKey(key string) string {
	return s.servicePrefix + hex.EncodeToString([]byte(key))
}
// setService stores srv under the service key derived from srv.String(),
// replacing any service previously stored under the same name.
func (s *State) setService(srv Service) {
	key := s.serviceKey(srv.String())
	s.Set(key, srv)
}
// deleteService removes srv from the State, locating it through the service
// key derived from srv.String(). Removing a service that is not stored does nothing.
func (s *State) deleteService(srv Service) {
	s.Delete(s.serviceKey(srv.String()))
}
// serviceKeys lists the State keys that carry the service prefix, in no
// particular order. Non-string keys and ordinary keys are ignored.
func (s *State) serviceKeys() []string {
	found := make([]string, 0, 8)
	s.dependencies.Range(func(k, _ any) bool {
		name, isString := k.(string)
		if isString && strings.HasPrefix(name, s.servicePrefix) {
			found = append(found, name)
		}
		return true
	})
	return found
}
// Services returns every service stored in the State. Each map key is the
// service's internal State key: the service prefix followed by the
// hex-encoded String() of the service.
//
// The map is built in one pass over the underlying sync.Map. A service removed
// concurrently is therefore simply omitted, rather than causing a panic between
// listing the keys and loading their values. Entries under a service key whose
// value is not a Service are skipped.
func (s *State) Services() map[string]Service {
	services := make(map[string]Service)
	s.dependencies.Range(func(key, value any) bool {
		keyStr, isString := key.(string)
		if !isString || !strings.HasPrefix(keyStr, s.servicePrefix) {
			return true
		}
		if srv, isService := value.(Service); isService {
			services[keyStr] = srv
		}
		return true
	})
	return services
}
// ServicesLen counts the entries stored under a service-prefixed string key.
func (s *State) ServicesLen() int {
	count := 0
	s.dependencies.Range(func(key, _ any) bool {
		str, isString := key.(string)
		if !isString || !strings.HasPrefix(str, s.servicePrefix) {
			return true
		}
		count++
		return true
	})
	return count
}
// GetService looks up the service whose String() value is key and asserts it
// to T. The boolean is false when no such service exists or it is not a T.
func GetService[T Service](s *State, key string) (T, bool) {
	return GetState[T](s, s.serviceKey(key))
}
// MustGetService is like GetService but panics when the service is missing
// or is not of type T.
func MustGetService[T Service](s *State, key string) T {
	if srv, ok := GetService[T](s, key); ok {
		return srv
	}
	panic("state: service not found!")
}
================================================
FILE: state_test.go
================================================
package fiber
import (
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
// TestState_SetAndGet_WithApp exercises Set/Get through the State owned by an app.
func TestState_SetAndGet_WithApp(t *testing.T) {
	t.Parallel()

	st := New().State()

	st.Set("foo", "bar")
	got, found := st.Get("foo")
	require.True(t, found)
	require.Equal(t, "bar", got)

	// A key that was never set is reported as absent.
	_, found = st.Get("unknown")
	require.False(t, found)
}
// TestState_SetAndGet exercises Set/Get on a standalone State.
func TestState_SetAndGet(t *testing.T) {
	t.Parallel()

	st := newState()

	st.Set("foo", "bar")
	got, found := st.Get("foo")
	require.True(t, found)
	require.Equal(t, "bar", got)

	// A key that was never set is reported as absent.
	_, found = st.Get("unknown")
	require.False(t, found)
}
// TestState_GetString covers a matching value, a mismatched type and a missing key.
func TestState_GetString(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("str", "hello")
	got, ok := st.GetString("str")
	require.True(t, ok)
	require.Equal(t, "hello", got)

	// A value of another type is rejected and yields "".
	st.Set("num", 123)
	got, ok = st.GetString("num")
	require.False(t, ok)
	require.Empty(t, got)

	// So does a key that was never set.
	got, ok = st.GetString("missing")
	require.False(t, ok)
	require.Empty(t, got)
}
// TestState_GetInt covers a matching value, a mismatched type and a missing key.
func TestState_GetInt(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("num", 456)
	got, ok := st.GetInt("num")
	require.True(t, ok)
	require.Equal(t, 456, got)

	// A value of another type is rejected and yields the zero value.
	st.Set("str", "abc")
	got, ok = st.GetInt("str")
	require.False(t, ok)
	require.Zero(t, got)

	// So does a key that was never set.
	got, ok = st.GetInt("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetBool covers a matching value, a mismatched type and a missing key.
func TestState_GetBool(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("flag", true)
	got, ok := st.GetBool("flag")
	require.True(t, ok)
	require.True(t, got)

	// An int is not converted to a bool.
	st.Set("num", 1)
	got, ok = st.GetBool("num")
	require.False(t, ok)
	require.False(t, got)

	// A key that was never set yields false.
	got, ok = st.GetBool("missing")
	require.False(t, ok)
	require.False(t, got)
}
// TestState_GetFloat64 covers a matching value, a mismatched type and a missing key.
func TestState_GetFloat64(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("pi", 3.14)
	got, ok := st.GetFloat64("pi")
	require.True(t, ok)
	require.InDelta(t, 3.14, got, 0.0001)

	// An int is not converted to a float64.
	st.Set("int", 10)
	got, ok = st.GetFloat64("int")
	require.False(t, ok)
	require.InDelta(t, 0.0, got, 0.0001)

	// A key that was never set yields the zero value.
	got, ok = st.GetFloat64("missing")
	require.False(t, ok)
	require.InDelta(t, 0.0, got, 0.0001)
}
// TestState_GetUint covers a matching value, a mismatched type and a missing key.
func TestState_GetUint(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("uint", uint(100))
	got, ok := st.GetUint("uint")
	require.True(t, ok)
	require.Equal(t, uint(100), got)

	st.Set("wrong", "not uint")
	got, ok = st.GetUint("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetUint("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetInt8 covers a matching value, a mismatched type and a missing key.
func TestState_GetInt8(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("int8", int8(10))
	got, ok := st.GetInt8("int8")
	require.True(t, ok)
	require.Equal(t, int8(10), got)

	st.Set("wrong", "not int8")
	got, ok = st.GetInt8("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetInt8("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetInt16 covers a matching value, a mismatched type and a missing key.
func TestState_GetInt16(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("int16", int16(200))
	got, ok := st.GetInt16("int16")
	require.True(t, ok)
	require.Equal(t, int16(200), got)

	st.Set("wrong", "not int16")
	got, ok = st.GetInt16("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetInt16("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetInt32 covers a matching value, a mismatched type and a missing key.
func TestState_GetInt32(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("int32", int32(3000))
	got, ok := st.GetInt32("int32")
	require.True(t, ok)
	require.Equal(t, int32(3000), got)

	st.Set("wrong", "not int32")
	got, ok = st.GetInt32("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetInt32("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetInt64 covers a matching value, a mismatched type and a missing key.
func TestState_GetInt64(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("int64", int64(4000))
	got, ok := st.GetInt64("int64")
	require.True(t, ok)
	require.Equal(t, int64(4000), got)

	st.Set("wrong", "not int64")
	got, ok = st.GetInt64("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetInt64("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetUint8 covers a matching value, a mismatched type and a missing key.
func TestState_GetUint8(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("uint8", uint8(20))
	got, ok := st.GetUint8("uint8")
	require.True(t, ok)
	require.Equal(t, uint8(20), got)

	st.Set("wrong", "not uint8")
	got, ok = st.GetUint8("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetUint8("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetUint16 covers a matching value, a mismatched type and a missing key.
func TestState_GetUint16(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("uint16", uint16(300))
	got, ok := st.GetUint16("uint16")
	require.True(t, ok)
	require.Equal(t, uint16(300), got)

	st.Set("wrong", "not uint16")
	got, ok = st.GetUint16("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetUint16("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetUint32 covers a matching value, a mismatched type and a missing key.
func TestState_GetUint32(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("uint32", uint32(400000))
	got, ok := st.GetUint32("uint32")
	require.True(t, ok)
	require.Equal(t, uint32(400000), got)

	st.Set("wrong", "not uint32")
	got, ok = st.GetUint32("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetUint32("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetUint64 covers a matching value, a mismatched type and a missing key.
func TestState_GetUint64(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("uint64", uint64(5000000))
	got, ok := st.GetUint64("uint64")
	require.True(t, ok)
	require.Equal(t, uint64(5000000), got)

	st.Set("wrong", "not uint64")
	got, ok = st.GetUint64("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetUint64("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetUintptr covers a matching value, a mismatched type and a missing key.
func TestState_GetUintptr(t *testing.T) {
	t.Parallel()
	st := newState()

	const want uintptr = 12345
	st.Set("uintptr", want)
	got, ok := st.GetUintptr("uintptr")
	require.True(t, ok)
	require.Equal(t, want, got)

	st.Set("wrong", "not uintptr")
	got, ok = st.GetUintptr("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetUintptr("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetFloat32 covers a matching value, a mismatched type and a missing key.
func TestState_GetFloat32(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("float32", float32(3.14))
	got, ok := st.GetFloat32("float32")
	require.True(t, ok)
	require.InDelta(t, float32(3.14), got, 0.0001)

	st.Set("wrong", "not float32")
	got, ok = st.GetFloat32("wrong")
	require.False(t, ok)
	require.InDelta(t, float32(0), got, 0.0001)

	got, ok = st.GetFloat32("missing")
	require.False(t, ok)
	require.InDelta(t, float32(0), got, 0.0001)
}
// TestState_GetComplex64 covers a matching value, a mismatched type and a missing key.
func TestState_GetComplex64(t *testing.T) {
	t.Parallel()
	st := newState()

	want := complex64(2 + 3i)
	st.Set("complex64", want)
	got, ok := st.GetComplex64("complex64")
	require.True(t, ok)
	require.Equal(t, want, got)

	st.Set("wrong", "not complex64")
	got, ok = st.GetComplex64("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetComplex64("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_GetComplex128 covers a matching value, a mismatched type and a missing key.
func TestState_GetComplex128(t *testing.T) {
	t.Parallel()
	st := newState()

	want := 4 + 5i // untyped complex constant defaults to complex128
	st.Set("complex128", want)
	got, ok := st.GetComplex128("complex128")
	require.True(t, ok)
	require.Equal(t, want, got)

	st.Set("wrong", "not complex128")
	got, ok = st.GetComplex128("wrong")
	require.False(t, ok)
	require.Zero(t, got)

	got, ok = st.GetComplex128("missing")
	require.False(t, ok)
	require.Zero(t, got)
}
// TestState_MustGet checks that MustGet returns present values and panics on absent keys.
func TestState_MustGet(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("exists", "value")
	require.Equal(t, "value", st.MustGet("exists"))

	require.Panics(t, func() {
		_ = st.MustGet("missing")
	})
}
// TestState_Has checks key presence for missing keys, set keys, and keys
// holding a nil value.
func TestState_Has(t *testing.T) {
	t.Parallel()
	st := newState()

	// Nothing has been stored yet.
	require.False(t, st.Has("key"))

	st.Set("key", "value")
	require.True(t, st.Has("key"))
	require.False(t, st.Has("missing"))

	// Presence is about the key, not the value: a nil value still counts.
	st.Set("nil", nil)
	require.True(t, st.Has("nil"))
}
// TestState_Delete checks that a deleted key can no longer be retrieved.
func TestState_Delete(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("key", "value")
	st.Delete("key")

	_, found := st.Get("key")
	require.False(t, found)
}
// TestState_Reset checks that Reset leaves the State empty.
func TestState_Reset(t *testing.T) {
	t.Parallel()
	st := newState()

	for i, k := range []string{"a", "b"} {
		st.Set(k, i+1)
	}
	st.Reset()

	require.Zero(t, st.Len())
	require.Empty(t, st.Keys())
}
// TestState_Keys checks that Keys returns exactly the stored string keys.
func TestState_Keys(t *testing.T) {
	t.Parallel()
	st := newState()

	want := []string{"one", "two", "three"}
	for _, key := range want {
		st.Set(key, key)
	}

	require.ElementsMatch(t, want, st.Keys())
}
// TestState_Keys_SkipsNonStringKeys checks that entries stored directly under
// non-string keys are left out of Keys.
func TestState_Keys_SkipsNonStringKeys(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("one", "one")
	st.Set("two", "two")
	// Bypass Set to plant keys of other types.
	st.dependencies.Store(42, "value")
	st.dependencies.Store(struct{}{}, "value")

	require.ElementsMatch(t, []string{"one", "two"}, st.Keys())
}
// TestState_Keys_SkipsNonStringKeys_WithMixedOrder interleaves a non-string key
// between string keys and checks only the string keys are returned.
func TestState_Keys_SkipsNonStringKeys_WithMixedOrder(t *testing.T) {
	t.Parallel()
	st := newState()

	st.Set("before", "value")
	st.dependencies.Store(42, "value")
	st.Set("after", "value")

	require.ElementsMatch(t, []string{"before", "after"}, st.Keys())
}
// TestState_Len checks that Len tracks insertions and deletions.
func TestState_Len(t *testing.T) {
	t.Parallel()
	st := newState()

	require.Zero(t, st.Len())
	for i, key := range []string{"a", "b"} {
		st.Set(key, key)
		require.Equal(t, i+1, st.Len())
	}

	st.Delete("a")
	require.Equal(t, 1, st.Len())
}
// testCase describes one lookup for runGenericTest: value is stored under key,
// then read back as a T and compared against expected and ok.
type testCase[T any] struct {
	value    any    // raw value stored in the State
	expected T      // value the getter should return
	name     string // label used in assertion messages
	key      string // State key under which value is stored
	ok       bool   // expected success flag from the getter
}
// runGenericTest stores each case's value in one fresh State and checks that
// getter returns the expected value and success flag for it.
func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) {
	t.Helper()
	st := newState()
	for i := range tests {
		tc := tests[i]
		st.Set(tc.key, tc.value)
		got, ok := getter(st, tc.key)
		require.Equal(t, tc.ok, ok, tc.name)
		require.Equal(t, tc.expected, got, tc.name)
	}
}
func TestState_GetGeneric(t *testing.T) {
t.Parallel()
runGenericTest(t, GetState[int], []testCase[int]{
{name: "int correct conversion", key: "num", value: 42, expected: 42, ok: true},
{name: "int wrong conversion from string", key: "str", value: "abc", expected: 0, ok: false},
})
runGenericTest(t, GetState[string], []testCase[string]{
{name: "string correct conversion", key: "strVal", value: "hello", expected: "hello", ok: true},
{name: "string wrong conversion from int", key: "intVal", value: 100, expected: "", ok: false},
})
runGenericTest(t, GetState[bool], []testCase[bool]{
{name: "bool correct conversion", key: "flag", value: true, expected: true, ok: true},
{name: "bool wrong conversion from int", key: "intFlag", value: 1, expected: false, ok: false},
})
runGenericTest(t, GetState[float64], []testCase[float64]{
{name: "float64 correct conversion", key: "pi", value: 3.14, expected: 3.14, ok: true},
{name: "float64 wrong conversion from int", key: "intVal", value: 10, expected: 0.0, ok: false},
})
}
// Test_MustGetStateGeneric checks MustGetState on a match, a type mismatch and
// a missing key.
func Test_MustGetStateGeneric(t *testing.T) {
	t.Parallel()
	st := newState()
	st.Set("flag", true)

	require.True(t, MustGetState[bool](st, "flag"))

	// Both a type mismatch and an absent key must panic.
	require.Panics(t, func() {
		_ = MustGetState[string](st, "flag")
	})
	require.Panics(t, func() {
		_ = MustGetState[string](st, "missing")
	})
}
// Test_GetStateWithDefault checks that the stored value wins when it matches,
// and the default is returned on a type mismatch or a missing key.
func Test_GetStateWithDefault(t *testing.T) {
	t.Parallel()
	st := newState()
	st.Set("flag", true)

	require.True(t, GetStateWithDefault(st, "flag", false))
	require.Equal(t, "default", GetStateWithDefault(st, "flag", "default"))
	require.False(t, GetStateWithDefault(st, "missing", false))
}
// TestState_Service exercises the internal service registry of State:
// storing and retrieving services, counting them, listing their keys and
// deleting them. All subtest groups run in parallel; each leaf subtest builds
// its own State and only reads the shared srv1/srv2 fixtures.
func TestState_Service(t *testing.T) {
	t.Parallel()
	srv1 := &mockService{name: "test1"}
	// service 2 is using a very subtle name to check it is not picked up
	srv2 := &mockService{name: "test1 "}

	t.Run("set/get/ok", func(t *testing.T) {
		t.Parallel()
		st := newState()
		st.setService(srv1)
		got, ok := st.Get(st.serviceKey(srv1.String()))
		require.True(t, ok)
		require.Equal(t, srv1, got)
	})

	t.Run("set/get/ko", func(t *testing.T) {
		t.Parallel()
		st := newState()
		st.setService(srv1)
		koSrv := &mockService{name: "ko"}
		got, ok := st.Get(st.serviceKey(koSrv.String()))
		require.False(t, ok)
		require.Nil(t, got)
	})

	t.Run("len", func(t *testing.T) {
		t.Parallel()
		t.Run("empty", func(t *testing.T) {
			t.Parallel()
			st := newState()
			require.Equal(t, 0, st.Len())
			require.Empty(t, st.serviceKeys())
		})
		t.Run("with-services", func(t *testing.T) {
			t.Parallel()
			st := newState()
			st.setService(srv1)
			st.setService(srv2)
			require.Equal(t, 2, st.Len())
			require.Equal(t, 2, st.ServicesLen())
		})
		t.Run("with-services/with-keys", func(t *testing.T) {
			t.Parallel()
			st := newState()
			st.setService(srv1)
			st.setService(srv2)
			st.Set("key1", "value1")
			st.Set("key2", "value2")
			servicesLen := st.ServicesLen()
			// Len counts every entry; ServicesLen counts only services.
			require.Equal(t, 4, st.Len())
			require.Equal(t, 2, servicesLen)
		})
	})

	t.Run("keys", func(t *testing.T) {
		// Parallel like every other subtest group in this test; the group
		// only reads the shared srv1/srv2 fixtures.
		t.Parallel()
		t.Run("empty", func(t *testing.T) {
			t.Parallel()
			st := newState()
			// adding more keys to check they are not included
			st.Set("key1", "value1")
			st.Set("key2", "value2")
			require.Empty(t, st.serviceKeys())
		})
		t.Run("with-services", func(t *testing.T) {
			t.Parallel()
			st := newState()
			st.setService(srv1)
			st.setService(srv2)
			keys := st.serviceKeys()
			require.Len(t, keys, 2)
			require.Contains(t, keys, st.serviceKey(srv1.String()))
			require.Contains(t, keys, st.serviceKey(srv2.String()))
		})
		t.Run("with-services/with-keys", func(t *testing.T) {
			t.Parallel()
			st := newState()
			st.setService(srv1)
			st.setService(srv2)
			st.Set("key1", "value1")
			st.Set("key2", "value2")
			keys := st.serviceKeys()
			require.Len(t, keys, 2)
			require.Contains(t, keys, st.serviceKey(srv1.String()))
			require.Contains(t, keys, st.serviceKey(srv2.String()))
			require.NotContains(t, keys, "key1")
			require.NotContains(t, keys, "key2")
		})
		t.Run("with-non-string-key", func(t *testing.T) {
			t.Parallel()
			st := newState()
			st.setService(srv1)
			st.setService(srv2)
			// Store non-string keys directly to check serviceKeys skips them.
			st.dependencies.Store(42, "value")
			st.dependencies.Store(struct{}{}, "value")
			keys := st.serviceKeys()
			require.Len(t, keys, 2)
			require.Contains(t, keys, st.serviceKey(srv1.String()))
			require.Contains(t, keys, st.serviceKey(srv2.String()))
		})
		t.Run("with-non-service-keys", func(t *testing.T) {
			t.Parallel()
			st := newState()
			st.setService(srv1)
			st.setService(srv2)
			st.Set("other", "value")
			st.dependencies.Store(42, "value")
			keys := st.serviceKeys()
			require.Len(t, keys, 2)
			require.Contains(t, keys, st.serviceKey(srv1.String()))
			require.Contains(t, keys, st.serviceKey(srv2.String()))
		})
	})

	t.Run("delete", func(t *testing.T) {
		t.Parallel()
		t.Run("ok", func(t *testing.T) {
			t.Parallel()
			st := newState()
			st.setService(srv1)
			st.deleteService(srv1)
			_, ok := st.Get(st.serviceKey(srv1.String()))
			require.False(t, ok)
		})
		t.Run("missing", func(t *testing.T) {
			t.Parallel()
			st := newState()
			st.setService(srv1)
			// Deleting an unregistered service must leave srv1 untouched.
			st.deleteService(srv2)
			_, ok := st.Get(st.serviceKey(srv1.String()))
			require.True(t, ok)
			_, ok = st.Get(st.serviceKey(srv2.String()))
			require.False(t, ok)
		})
	})
}
// TestState_GetService verifies the generic GetService accessor: a registered
// service is returned with ok == true, and an unregistered name yields a nil
// value with ok == false.
func TestState_GetService(t *testing.T) {
	t.Parallel()

	t.Run("ok", func(t *testing.T) {
		t.Parallel()
		svc := &mockService{name: "test1"}
		state := newState()
		state.setService(svc)

		found, exists := GetService[*mockService](state, svc.String())
		require.True(t, exists)
		require.Equal(t, svc, found)
	})

	t.Run("ko", func(t *testing.T) {
		t.Parallel()
		svc := &mockService{name: "test1"}
		state := newState()

		// Nothing was registered, so the lookup must miss.
		found, exists := GetService[*mockService](state, svc.String())
		require.False(t, exists)
		require.Nil(t, found)
	})
}
// TestState_MustGetService verifies that MustGetService returns a registered
// service and panics when the requested service was never registered.
func TestState_MustGetService(t *testing.T) {
	t.Parallel()

	t.Run("ok", func(t *testing.T) {
		t.Parallel()
		svc := &mockService{name: "test1"}
		state := newState()
		state.setService(svc)

		require.Equal(t, svc, MustGetService[*mockService](state, svc.String()))
	})

	t.Run("panics", func(t *testing.T) {
		t.Parallel()
		svc := &mockService{name: "test1"}
		state := newState()

		// The state is empty, so the lookup must panic.
		require.Panics(t, func() {
			_ = MustGetService[*mockService](state, svc.String())
		})
	})
}
// BenchmarkState_Set measures State.Set when every iteration writes a new,
// previously unseen key, so the state keeps growing for the whole run.
func BenchmarkState_Set(b *testing.B) {
	b.ReportAllocs()
	s := newState()
	var seq int
	for b.Loop() {
		seq++
		s.Set("key"+strconv.Itoa(seq), seq)
	}
}
// BenchmarkState_Get measures State.Get over 1000 prepopulated int entries.
// Keys are built before the timed loop so that the allocations reported by
// b.ReportAllocs are not attributed to key construction.
func BenchmarkState_Get(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}
	i := 0
	for b.Loop() {
		i++
		st.Get(keys[i%n])
	}
}
// BenchmarkState_GetString measures State.GetString over 1000 prepopulated
// string entries. Keys are built before the timed loop so that reported
// allocations are not attributed to key construction.
func BenchmarkState_GetString(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], strconv.Itoa(i))
	}
	i := 0
	for b.Loop() {
		i++
		st.GetString(keys[i%n])
	}
}
// BenchmarkState_GetInt measures State.GetInt over 1000 prepopulated int
// entries. Keys are built before the timed loop so that reported allocations
// are not attributed to key construction.
func BenchmarkState_GetInt(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}
	i := 0
	for b.Loop() {
		i++
		st.GetInt(keys[i%n])
	}
}
// BenchmarkState_GetBool measures State.GetBool over 1000 prepopulated bool
// entries (alternating true/false). Keys are built before the timed loop so
// that reported allocations are not attributed to key construction.
func BenchmarkState_GetBool(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i%2 == 0)
	}
	i := 0
	for b.Loop() {
		i++
		st.GetBool(keys[i%n])
	}
}
// BenchmarkState_GetFloat64 measures State.GetFloat64 over 1000 prepopulated
// float64 entries. Keys are built before the timed loop so that reported
// allocations are not attributed to key construction.
func BenchmarkState_GetFloat64(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], float64(i))
	}
	i := 0
	for b.Loop() {
		i++
		st.GetFloat64(keys[i%n])
	}
}
// BenchmarkState_MustGet measures State.MustGet over 1000 prepopulated int
// entries; every key looked up exists. Keys are built before the timed loop
// so that reported allocations are not attributed to key construction.
func BenchmarkState_MustGet(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}
	i := 0
	for b.Loop() {
		i++
		st.MustGet(keys[i%n])
	}
}
// BenchmarkState_GetStateGeneric measures the generic GetState[int] accessor
// over 1000 prepopulated int entries. Keys are built before the timed loop so
// that reported allocations are not attributed to key construction.
func BenchmarkState_GetStateGeneric(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}
	i := 0
	for b.Loop() {
		i++
		GetState[int](st, keys[i%n])
	}
}
// BenchmarkState_MustGetStateGeneric measures the generic MustGetState[int]
// accessor over 1000 prepopulated int entries; every lookup hits. Keys are
// built before the timed loop so that reported allocations are not attributed
// to key construction.
func BenchmarkState_MustGetStateGeneric(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}
	i := 0
	for b.Loop() {
		i++
		MustGetState[int](st, keys[i%n])
	}
}
// BenchmarkState_GetStateWithDefault measures GetStateWithDefault over 1000
// prepopulated int entries; every lookup hits. Keys are built before the timed
// loop so that reported allocations are not attributed to key construction.
func BenchmarkState_GetStateWithDefault(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}
	i := 0
	for b.Loop() {
		i++
		GetStateWithDefault(st, keys[i%n], 0)
	}
}
// BenchmarkState_Has measures State.Has over 1000 prepopulated entries; every
// key checked exists. Keys are built before the timed loop so that reported
// allocations are not attributed to key construction.
func BenchmarkState_Has(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// prepopulate the state and remember the keys
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], i)
	}
	i := 0
	for b.Loop() {
		i++
		st.Has(keys[i%n])
	}
}
// BenchmarkState_Delete measures creating a fresh state, storing a single
// entry and deleting it again, all inside the timed loop.
func BenchmarkState_Delete(b *testing.B) {
	b.ReportAllocs()
	for b.Loop() {
		fresh := newState()
		fresh.Set("a", 1)
		fresh.Delete("a")
	}
}
// BenchmarkState_Reset measures clearing a state that holds a fixed number of
// entries. Building and filling the state is part of each timed iteration.
func BenchmarkState_Reset(b *testing.B) {
	b.ReportAllocs()
	const entries = 100
	for b.Loop() {
		fresh := newState()
		// Fill with a fixed number of entries so Reset has work to do.
		for k := range entries {
			fresh.Set("key"+strconv.Itoa(k), k)
		}
		fresh.Reset()
	}
}
// BenchmarkState_Keys measures listing every key of a state holding 1000
// entries.
func BenchmarkState_Keys(b *testing.B) {
	b.ReportAllocs()
	const size = 1000
	s := newState()
	for k := range size {
		s.Set("key"+strconv.Itoa(k), k)
	}
	for b.Loop() {
		_ = s.Keys()
	}
}
// BenchmarkState_Len measures counting the entries of a state holding 1000
// entries.
func BenchmarkState_Len(b *testing.B) {
	b.ReportAllocs()
	const size = 1000
	s := newState()
	for k := range size {
		s.Set("key"+strconv.Itoa(k), k)
	}
	for b.Loop() {
		_ = s.Len()
	}
}
// BenchmarkState_GetUint measures State.GetUint over 1000 prepopulated uint
// entries. Keys are built before the timed loop so that reported allocations
// are not attributed to key construction.
func BenchmarkState_GetUint(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with uint values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint(i)) //nolint:gosec // G115 - test values are small
	}
	i := 0
	for b.Loop() {
		i++
		st.GetUint(keys[i%n])
	}
}
// BenchmarkState_GetInt8 measures State.GetInt8 over 1000 prepopulated int8
// entries. Keys are built before the timed loop so that reported allocations
// are not attributed to key construction.
func BenchmarkState_GetInt8(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with int8 values (using modulo to stay in range).
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], int8(i%128)) //nolint:gosec // G115 - modulo keeps value in range
	}
	i := 0
	for b.Loop() {
		i++
		st.GetInt8(keys[i%n])
	}
}
// BenchmarkState_GetInt16 measures State.GetInt16 over 1000 prepopulated int16
// entries. Keys are built before the timed loop so that reported allocations
// are not attributed to key construction.
func BenchmarkState_GetInt16(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with int16 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], int16(i)) //nolint:gosec // G115 - test values are small
	}
	i := 0
	for b.Loop() {
		i++
		st.GetInt16(keys[i%n])
	}
}
// BenchmarkState_GetInt32 measures State.GetInt32 over 1000 prepopulated int32
// entries. Keys are built before the timed loop so that reported allocations
// are not attributed to key construction.
func BenchmarkState_GetInt32(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with int32 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], int32(i)) //nolint:gosec // G115 - test values are small
	}
	i := 0
	for b.Loop() {
		i++
		st.GetInt32(keys[i%n])
	}
}
// BenchmarkState_GetInt64 measures State.GetInt64 over 1000 prepopulated int64
// entries. Keys are built before the timed loop so that reported allocations
// are not attributed to key construction.
func BenchmarkState_GetInt64(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with int64 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], int64(i))
	}
	i := 0
	for b.Loop() {
		i++
		st.GetInt64(keys[i%n])
	}
}
// BenchmarkState_GetUint8 measures State.GetUint8 over 1000 prepopulated uint8
// entries. Keys are built before the timed loop so that reported allocations
// are not attributed to key construction.
func BenchmarkState_GetUint8(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with uint8 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint8(i%256)) //nolint:gosec // G115 - modulo keeps value in range
	}
	i := 0
	for b.Loop() {
		i++
		st.GetUint8(keys[i%n])
	}
}
// BenchmarkState_GetUint16 measures State.GetUint16 over 1000 prepopulated
// uint16 entries. Keys are built before the timed loop so that reported
// allocations are not attributed to key construction.
func BenchmarkState_GetUint16(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with uint16 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint16(i)) //nolint:gosec // G115 - test values are small
	}
	i := 0
	for b.Loop() {
		i++
		st.GetUint16(keys[i%n])
	}
}
// BenchmarkState_GetUint32 measures State.GetUint32 over 1000 prepopulated
// uint32 entries. Keys are built before the timed loop so that reported
// allocations are not attributed to key construction.
func BenchmarkState_GetUint32(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with uint32 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint32(i)) //nolint:gosec // G115 - test values are small
	}
	i := 0
	for b.Loop() {
		i++
		st.GetUint32(keys[i%n])
	}
}
// BenchmarkState_GetUint64 measures State.GetUint64 over 1000 prepopulated
// uint64 entries. Keys are built before the timed loop so that reported
// allocations are not attributed to key construction.
func BenchmarkState_GetUint64(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with uint64 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uint64(i)) //nolint:gosec // G115 - test values are small
	}
	i := 0
	for b.Loop() {
		i++
		st.GetUint64(keys[i%n])
	}
}
// BenchmarkState_GetUintptr measures State.GetUintptr over 1000 prepopulated
// uintptr entries. Keys are built before the timed loop so that reported
// allocations are not attributed to key construction.
func BenchmarkState_GetUintptr(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with uintptr values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], uintptr(i))
	}
	i := 0
	for b.Loop() {
		i++
		st.GetUintptr(keys[i%n])
	}
}
// BenchmarkState_GetFloat32 measures State.GetFloat32 over 1000 prepopulated
// float32 entries. Keys are built before the timed loop so that reported
// allocations are not attributed to key construction.
func BenchmarkState_GetFloat32(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with float32 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		st.Set(keys[i], float32(i))
	}
	i := 0
	for b.Loop() {
		i++
		st.GetFloat32(keys[i%n])
	}
}
// BenchmarkState_GetComplex64 measures State.GetComplex64 over 1000
// prepopulated complex64 entries. Keys are built before the timed loop so that
// reported allocations are not attributed to key construction.
func BenchmarkState_GetComplex64(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with complex64 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		// Create a complex64 value with both real and imaginary parts.
		st.Set(keys[i], complex(float32(i), float32(i)))
	}
	i := 0
	for b.Loop() {
		i++
		st.GetComplex64(keys[i%n])
	}
}
// BenchmarkState_GetComplex128 measures State.GetComplex128 over 1000
// prepopulated complex128 entries. Keys are built before the timed loop so
// that reported allocations are not attributed to key construction.
func BenchmarkState_GetComplex128(b *testing.B) {
	b.ReportAllocs()
	st := newState()
	const n = 1000
	// Prepopulate the state with complex128 values and remember the keys.
	keys := make([]string, n)
	for i := range keys {
		keys[i] = "key" + strconv.Itoa(i)
		// Create a complex128 value with both real and imaginary parts.
		st.Set(keys[i], complex(float64(i), float64(i)))
	}
	i := 0
	for b.Loop() {
		i++
		st.GetComplex128(keys[i%n])
	}
}
// BenchmarkState_GetService measures the generic GetService lookup for a
// single registered service.
func BenchmarkState_GetService(b *testing.B) {
	b.ReportAllocs()
	s := newState()
	svc := &mockService{name: "benchService"}
	s.setService(svc)
	for b.Loop() {
		_, _ = GetService[*mockService](s, svc.String())
	}
}
// BenchmarkState_MustGetService measures the panicking MustGetService lookup
// for a single registered service; every lookup hits.
func BenchmarkState_MustGetService(b *testing.B) {
	b.ReportAllocs()
	s := newState()
	svc := &mockService{name: "benchService"}
	s.setService(svc)
	for b.Loop() {
		_ = MustGetService[*mockService](s, svc.String())
	}
}
================================================
FILE: storage_interface.go
================================================
package fiber
import (
"context"
"time"
)
// Storage is the interface for communicating with different database and
// key-value providers.
type Storage interface {
	// GetWithContext gets the value for the given key with a context.
	// `nil, nil` is returned when the key does not exist.
	GetWithContext(ctx context.Context, key string) ([]byte, error)

	// Get gets the value for the given key.
	// `nil, nil` is returned when the key does not exist.
	Get(key string) ([]byte, error)

	// SetWithContext stores the given value for the given key with a context
	// and an expiration value; 0 means no expiration.
	// Empty key or value will be ignored without an error.
	SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error

	// Set stores the given value for the given key along
	// with an expiration value; 0 means no expiration.
	// Empty key or value will be ignored without an error.
	Set(key string, val []byte, exp time.Duration) error

	// DeleteWithContext deletes the value for the given key with a context.
	// It returns no error if the storage does not contain the key.
	DeleteWithContext(ctx context.Context, key string) error

	// Delete deletes the value for the given key.
	// It returns no error if the storage does not contain the key.
	Delete(key string) error

	// ResetWithContext resets the storage and deletes all keys with a context.
	ResetWithContext(ctx context.Context) error

	// Reset resets the storage and deletes all keys.
	Reset() error

	// Close closes the storage and will stop any running garbage
	// collectors and open connections.
	Close() error
}